From 1fa65d89ca5dc64756f7933d7cc3f524e4627dce Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Mon, 22 Mar 2004 04:15:03 +0000 Subject: Adjusted some details, got rid of old files, added some new. --- src/http.lua | 41 +++++++++++++++++++++++------------------ src/ltn12.lua | 13 +++++++++++++ src/smtp.lua | 43 ++++++++++++++++++++++--------------------- 3 files changed, 58 insertions(+), 39 deletions(-) (limited to 'src') diff --git a/src/http.lua b/src/http.lua index a10cf50..ab166e3 100644 --- a/src/http.lua +++ b/src/http.lua @@ -2,7 +2,7 @@ -- HTTP/1.1 client support for the Lua language. -- LuaSocket toolkit. -- Author: Diego Nehab --- Conforming to: RFC 2616, LTN7 +-- Conforming to RFC 2616 -- RCS ID: $Id$ ----------------------------------------------------------------------------- -- make sure LuaSocket is loaded @@ -39,21 +39,18 @@ local function third(a, b, c) return c end -local function shift(a, b, c, d) - return c, d -end - --- resquest_p forward declaration -local request_p - -local function receive_headers(sock, headers) - local line, name, value +local function receive_headers(reqt, respt) + local headers = {} + local sock = respt.tmp.sock + local line, name, value, _ + -- store results + respt.headers = headers -- get first line line = socket.try(sock:receive()) -- headers go until a blank line is found while line ~= "" do -- get field-name and value - name, value = shift(string.find(line, "^(.-):%s*(.*)")) + _, _, name, value = string.find(line, "^(.-):%s*(.*)") assert(name and value, "malformed reponse headers") name = string.lower(name) -- get next line (value might be folded) @@ -100,7 +97,10 @@ local function receive_body_bychunks(sock, sink) -- let callback know we are done hand(sink, nil) -- servers shouldn't send trailer headers, but who trusts them? - receive_headers(sock, {}) + local line = socket.try(sock:receive()) + while line ~= "" do + line = socket.try(sock:receive()) + end end local function receive_body_bylength(sock, length, sink) @@ -245,7 +245,7 @@ local function open(reqt, respt) socket.try(sock:connect(host, port)) end -function adjust_headers(reqt, respt) +local function adjust_headers(reqt, respt) local lower = {} local headers = reqt.headers or {} -- set default headers @@ -261,7 +261,7 @@ function adjust_headers(reqt, respt) respt.tmp.headers = lower end -function parse_url(reqt, respt) +local function parse_url(reqt, respt) -- parse url with default fields local parsed = socket.url.parse(reqt.url, { host = "", @@ -280,11 +280,16 @@ function parse_url(reqt, respt) respt.tmp.parsed = parsed end +-- forward declaration +local request_p + local function should_authorize(reqt, respt) -- if there has been an authorization attempt, it must have failed if reqt.headers and reqt.headers["authorization"] then return nil end - -- if we don't have authorization information, we can't retry - return respt.tmp.parsed.user and respt.tmp.parsed.password + -- if last attempt didn't fail due to lack of authentication, + -- or we don't have authorization information, we can't retry + return respt.code == 401 and + respt.tmp.parsed.user and respt.tmp.parsed.password end local function clone(headers) @@ -338,14 +343,14 @@ local function redirect(reqt, respt) if respt.headers then respt.headers.location = redirt.url end end +-- execute a request of through an exception function request_p(reqt, respt) parse_url(reqt, respt) adjust_headers(reqt, respt) open(reqt, respt) send_request(reqt, respt) receive_status(reqt, respt) - respt.headers = {} - receive_headers(respt.tmp.sock, respt.headers) + receive_headers(reqt, respt) if should_redirect(reqt, respt) then respt.tmp.sock:close() redirect(reqt, respt) diff --git a/src/ltn12.lua b/src/ltn12.lua index dc49d80..ed3449b 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -22,6 +22,7 @@ end -- returns a high level filter that cycles a cycles a low-level filter function filter.cycle(low, ctx, extra) + if type(low) ~= 'function' then error('invalid low-level filter', 2) end return function(chunk) local ret ret, ctx = low(ctx, chunk, extra) @@ -31,6 +32,8 @@ end -- chains two filters together local function chain2(f1, f2) + if type(f1) ~= 'function' then error('invalid filter', 2) end + if type(f2) ~= 'function' then error('invalid filter', 2) end return function(chunk) return f2(f1(chunk)) end @@ -40,6 +43,7 @@ end function filter.chain(...) local f = arg[1] for i = 2, table.getn(arg) do + if type(arg[i]) ~= 'function' then error('invalid filter', 2) end f = chain2(f, arg[i]) end return f @@ -74,6 +78,7 @@ end -- turns a fancy source into a simple source function source.simplify(src) + if type(src) ~= 'function' then error('invalid source', 2) end return function() local chunk, err_or_new = src() src = err_or_new or src @@ -97,6 +102,7 @@ end -- creates rewindable source function source.rewind(src) + if type(src) ~= 'function' then error('invalid source', 2) end local t = {} return function(chunk) if not chunk then @@ -111,6 +117,8 @@ end -- chains a source with a filter function source.chain(src, f) + if type(src) ~= 'function' then error('invalid source', 2) end + if type(f) ~= 'function' then error('invalid filter', 2) end local co = coroutine.create(function() while true do local chunk, err = src() @@ -157,6 +165,7 @@ end -- turns a fancy sink into a simple sink function sink.simplify(snk) + if type(snk) ~= 'function' then error('invalid sink', 2) end return function(chunk, err) local ret, err_or_new = snk(chunk, err) if not ret then return nil, err_or_new end @@ -195,6 +204,8 @@ end -- chains a sink with a filter function sink.chain(f, snk) + if type(snk) ~= 'function' then error('invalid sink', 2) end + if type(f) ~= 'function' then error('invalid filter', 2) end return function(chunk, err) local filtered = f(chunk) local done = chunk and "" @@ -209,6 +220,8 @@ end -- pumps all data from a source to a sink function pump(src, snk) + if type(src) ~= 'function' then error('invalid source', 2) end + if type(snk) ~= 'function' then error('invalid sink', 2) end while true do local chunk, src_err = src() local ret, snk_err = snk(chunk, src_err) diff --git a/src/smtp.lua b/src/smtp.lua index c823c97..ed8bd15 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -20,16 +20,17 @@ DOMAIN = os.getenv("SERVER_NAME") or "localhost" -- default time zone (means we don't know) ZONE = "-0000" -function stuff() - return ltn12.filter.cycle(dot, 2) -end - local function shift(a, b, c) return b, c end +-- high level stuffing filter +function stuff() + return ltn12.filter.cycle(dot, 2) +end + -- send message or throw an exception -function psend(control, mailt) +local function send_p(control, mailt) socket.try(control:check("2..")) socket.try(control:command("EHLO", mailt.domain or DOMAIN)) socket.try(control:check("2..")) @@ -61,11 +62,11 @@ local function newboundary() math.random(0, 99999), seqno) end --- sendmessage forward declaration -local sendmessage +-- send_message forward declaration +local send_message -- yield multipart message body from a multipart message table -local function sendmultipart(mesgt) +local function send_multipart(mesgt) local bd = newboundary() -- define boundary and finish headers coroutine.yield('content-type: multipart/mixed; boundary="' .. @@ -75,7 +76,7 @@ local function sendmultipart(mesgt) -- send each part separated by a boundary for i, m in ipairs(mesgt.body) do coroutine.yield("\r\n--" .. bd .. "\r\n") - sendmessage(m) + send_message(m) end -- send last boundary coroutine.yield("\r\n--" .. bd .. "--\r\n\r\n") @@ -84,7 +85,7 @@ local function sendmultipart(mesgt) end -- yield message body from a source -local function sendsource(mesgt) +local function send_source(mesgt) -- set content-type if user didn't override if not mesgt.headers or not mesgt.headers["content-type"] then coroutine.yield('content-type: text/plain; charset="iso-8859-1"\r\n') @@ -101,7 +102,7 @@ local function sendsource(mesgt) end -- yield message body from a string -local function sendstring(mesgt) +local function send_string(mesgt) -- set content-type if user didn't override if not mesgt.headers or not mesgt.headers["content-type"] then coroutine.yield('content-type: text/plain; charset="iso-8859-1"\r\n') @@ -114,7 +115,7 @@ local function sendstring(mesgt) end -- yield the headers one by one -local function sendheaders(mesgt) +local function send_headers(mesgt) if mesgt.headers then for i,v in pairs(mesgt.headers) do coroutine.yield(i .. ':' .. v .. "\r\n") @@ -123,15 +124,15 @@ local function sendheaders(mesgt) end -- message source -function sendmessage(mesgt) - sendheaders(mesgt) - if type(mesgt.body) == "table" then sendmultipart(mesgt) - elseif type(mesgt.body) == "function" then sendsource(mesgt) - else sendstring(mesgt) end +function send_message(mesgt) + send_headers(mesgt) + if type(mesgt.body) == "table" then send_multipart(mesgt) + elseif type(mesgt.body) == "function" then send_source(mesgt) + else send_string(mesgt) end end -- set defaul headers -local function adjustheaders(mesgt) +local function adjust_headers(mesgt) mesgt.headers = mesgt.headers or {} mesgt.headers["mime-version"] = "1.0" mesgt.headers["date"] = mesgt.headers["date"] or @@ -140,16 +141,16 @@ local function adjustheaders(mesgt) end function message(mesgt) - adjustheaders(mesgt) + adjust_headers(mesgt) -- create and return message source - local co = coroutine.create(function() sendmessage(mesgt) end) + local co = coroutine.create(function() send_message(mesgt) end) return function() return shift(coroutine.resume(co)) end end function send(mailt) local c, e = socket.tp.connect(mailt.server or SERVER, mailt.port or PORT) if not c then return nil, e end - local s, e = pcall(psend, c, mailt) + local s, e = pcall(send_p, c, mailt) c:close() if s then return true else return nil, e end -- cgit v1.2.3-55-g6feb