diff options
Diffstat (limited to 'src/http.lua')
| -rw-r--r-- | src/http.lua | 351 |
1 file changed, 147 insertions, 204 deletions
diff --git a/src/http.lua b/src/http.lua index ebe6b54..129b562 100644 --- a/src/http.lua +++ b/src/http.lua | |||
| @@ -7,7 +7,7 @@ | |||
| 7 | ----------------------------------------------------------------------------- | 7 | ----------------------------------------------------------------------------- |
| 8 | 8 | ||
| 9 | ----------------------------------------------------------------------------- | 9 | ----------------------------------------------------------------------------- |
| 10 | -- Load other required modules | 10 | -- Load required modules |
| 11 | ------------------------------------------------------------------------------- | 11 | ------------------------------------------------------------------------------- |
| 12 | local socket = require("socket") | 12 | local socket = require("socket") |
| 13 | local ltn12 = require("ltn12") | 13 | local ltn12 = require("ltn12") |
| @@ -17,42 +17,68 @@ local url = require("url") | |||
| 17 | ----------------------------------------------------------------------------- | 17 | ----------------------------------------------------------------------------- |
| 18 | -- Setup namespace | 18 | -- Setup namespace |
| 19 | ------------------------------------------------------------------------------- | 19 | ------------------------------------------------------------------------------- |
| 20 | http = {} | 20 | _LOADED["http"] = getfenv(1) |
| 21 | -- make all module globals fall into namespace | ||
| 22 | setmetatable(http, { __index = _G }) | ||
| 23 | setfenv(1, http) | ||
| 24 | 21 | ||
| 25 | ----------------------------------------------------------------------------- | 22 | ----------------------------------------------------------------------------- |
| 26 | -- Program constants | 23 | -- Program constants |
| 27 | ----------------------------------------------------------------------------- | 24 | ----------------------------------------------------------------------------- |
| 28 | -- connection timeout in seconds | 25 | -- connection timeout in seconds |
| 29 | TIMEOUT = 60 | 26 | TIMEOUT = 4 |
| 30 | -- default port for document retrieval | 27 | -- default port for document retrieval |
| 31 | PORT = 80 | 28 | PORT = 80 |
| 32 | -- user agent field sent in request | 29 | -- user agent field sent in request |
| 33 | USERAGENT = socket.version | 30 | USERAGENT = socket.VERSION |
| 34 | -- block size used in transfers | 31 | -- block size used in transfers |
| 35 | BLOCKSIZE = 2048 | 32 | BLOCKSIZE = 2048 |
| 36 | 33 | ||
| 37 | ----------------------------------------------------------------------------- | 34 | ----------------------------------------------------------------------------- |
| 38 | -- Function return value selectors | 35 | -- Low level HTTP API |
| 39 | ----------------------------------------------------------------------------- | 36 | ----------------------------------------------------------------------------- |
| 40 | local function second(a, b) | 37 | local metat = { __index = {} } |
| 41 | return b | 38 | |
| 39 | function open(host, port) | ||
| 40 | local con = socket.try(socket.tcp()) | ||
| 41 | socket.try(con:settimeout(TIMEOUT)) | ||
| 42 | port = port or PORT | ||
| 43 | socket.try(con:connect(host, port)) | ||
| 44 | return setmetatable({ con = con }, metat) | ||
| 45 | end | ||
| 46 | |||
| 47 | function metat.__index:sendrequestline(method, uri) | ||
| 48 | local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) | ||
| 49 | return socket.try(self.con:send(reqline)) | ||
| 42 | end | 50 | end |
| 43 | 51 | ||
| 44 | local function third(a, b, c) | 52 | function metat.__index:sendheaders(headers) |
| 45 | return c | 53 | for i, v in pairs(headers) do |
| 54 | socket.try(self.con:send(i .. ": " .. v .. "\r\n")) | ||
| 55 | end | ||
| 56 | -- mark end of request headers | ||
| 57 | socket.try(self.con:send("\r\n")) | ||
| 58 | return 1 | ||
| 46 | end | 59 | end |
| 47 | 60 | ||
| 48 | local function receive_headers(reqt, respt, tmp) | 61 | function metat.__index:sendbody(headers, source, step) |
| 49 | local sock = tmp.sock | 62 | source = source or ltn12.source.empty() |
| 63 | step = step or ltn12.pump.step | ||
| 64 | -- if we don't know the size in advance, send chunked and hope for the best | ||
| 65 | local mode | ||
| 66 | if headers["content-length"] then mode = "keep-open" | ||
| 67 | else mode = "http-chunked" end | ||
| 68 | return socket.try(ltn12.pump.all(source, socket.sink(mode, self.con), step)) | ||
| 69 | end | ||
| 70 | |||
| 71 | function metat.__index:receivestatusline() | ||
| 72 | local status = socket.try(self.con:receive()) | ||
| 73 | local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) | ||
| 74 | return socket.try(tonumber(code), status) | ||
| 75 | end | ||
| 76 | |||
| 77 | function metat.__index:receiveheaders() | ||
| 50 | local line, name, value | 78 | local line, name, value |
| 51 | local headers = {} | 79 | local headers = {} |
| 52 | -- store results | ||
| 53 | respt.headers = headers | ||
| 54 | -- get first line | 80 | -- get first line |
| 55 | line = socket.try(sock:receive()) | 81 | line = socket.try(self.con:receive()) |
| 56 | -- headers go until a blank line is found | 82 | -- headers go until a blank line is found |
| 57 | while line ~= "" do | 83 | while line ~= "" do |
| 58 | -- get field-name and value | 84 | -- get field-name and value |
| @@ -60,189 +86,137 @@ local function receive_headers(reqt, respt, tmp) | |||
| 60 | socket.try(name and value, "malformed reponse headers") | 86 | socket.try(name and value, "malformed reponse headers") |
| 61 | name = string.lower(name) | 87 | name = string.lower(name) |
| 62 | -- get next line (value might be folded) | 88 | -- get next line (value might be folded) |
| 63 | line = socket.try(sock:receive()) | 89 | line = socket.try(self.con:receive()) |
| 64 | -- unfold any folded values | 90 | -- unfold any folded values |
| 65 | while string.find(line, "^%s") do | 91 | while string.find(line, "^%s") do |
| 66 | value = value .. line | 92 | value = value .. line |
| 67 | line = socket.try(sock:receive()) | 93 | line = socket.try(self.con:receive()) |
| 68 | end | 94 | end |
| 69 | -- save pair in table | 95 | -- save pair in table |
| 70 | if headers[name] then headers[name] = headers[name] .. ", " .. value | 96 | if headers[name] then headers[name] = headers[name] .. ", " .. value |
| 71 | else headers[name] = value end | 97 | else headers[name] = value end |
| 72 | end | 98 | end |
| 99 | return headers | ||
| 73 | end | 100 | end |
| 74 | 101 | ||
| 75 | local function receive_body(reqt, respt, tmp) | 102 | function metat.__index:receivebody(headers, sink, step) |
| 76 | local sink = reqt.sink or ltn12.sink.null() | 103 | sink = sink or ltn12.sink.null() |
| 77 | local step = reqt.step or ltn12.pump.step | 104 | step = step or ltn12.pump.step |
| 78 | local source | 105 | local length = tonumber(headers["content-length"]) |
| 79 | local te = respt.headers["transfer-encoding"] | 106 | local TE = headers["transfer-encoding"] |
| 80 | if te and te ~= "identity" then | 107 | local mode |
| 81 | -- get by chunked transfer-coding of message body | 108 | if TE and TE ~= "identity" then mode = "http-chunked" |
| 82 | source = socket.source("http-chunked", tmp.sock) | 109 | elseif tonumber(headers["content-length"]) then mode = "by-length" |
| 83 | elseif tonumber(respt.headers["content-length"]) then | 110 | else mode = "default" end |
| 84 | -- get by content-length | 111 | return socket.try(ltn12.pump.all(socket.source(mode, self.con, length), |
| 85 | local length = tonumber(respt.headers["content-length"]) | 112 | sink, step)) |
| 86 | source = socket.source("by-length", tmp.sock, length) | ||
| 87 | else | ||
| 88 | -- get it all until connection closes | ||
| 89 | source = socket.source(tmp.sock) | ||
| 90 | end | ||
| 91 | socket.try(ltn12.pump.all(source, sink, step)) | ||
| 92 | end | 113 | end |
| 93 | 114 | ||
| 94 | local function send_headers(sock, headers) | 115 | function metat.__index:close() |
| 95 | -- send request headers | 116 | return self.con:close() |
| 96 | for i, v in pairs(headers) do | ||
| 97 | socket.try(sock:send(i .. ": " .. v .. "\r\n")) | ||
| 98 | end | ||
| 99 | -- mark end of request headers | ||
| 100 | socket.try(sock:send("\r\n")) | ||
| 101 | end | 117 | end |
| 102 | 118 | ||
| 103 | local function should_receive_body(reqt, respt, tmp) | 119 | ----------------------------------------------------------------------------- |
| 104 | if reqt.method == "HEAD" then return nil end | 120 | -- High level HTTP API |
| 105 | if respt.code == 204 or respt.code == 304 then return nil end | 121 | ----------------------------------------------------------------------------- |
| 106 | if respt.code >= 100 and respt.code < 200 then return nil end | 122 | local function uri(reqt) |
| 107 | return 1 | 123 | local u = reqt |
| 108 | end | 124 | if not reqt.proxy and not PROXY then |
| 109 | 125 | u = { | |
| 110 | local function receive_status(reqt, respt, tmp) | 126 | path = reqt.path, |
| 111 | local status = socket.try(tmp.sock:receive()) | 127 | params = reqt.params, |
| 112 | local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) | 128 | query = reqt.query, |
| 113 | -- store results | 129 | fragment = reqt.fragment |
| 114 | respt.code, respt.status = tonumber(code), status | ||
| 115 | end | ||
| 116 | |||
| 117 | local function request_uri(reqt, respt, tmp) | ||
| 118 | local u = tmp.parsed | ||
| 119 | if not reqt.proxy then | ||
| 120 | local parsed = tmp.parsed | ||
| 121 | u = { | ||
| 122 | path = parsed.path, | ||
| 123 | params = parsed.params, | ||
| 124 | query = parsed.query, | ||
| 125 | fragment = parsed.fragment | ||
| 126 | } | 130 | } |
| 127 | end | 131 | end |
| 128 | return url.build(u) | 132 | return url.build(u) |
| 129 | end | 133 | end |
| 130 | 134 | ||
| 131 | local function send_request(reqt, respt, tmp) | 135 | local function adjustheaders(headers, host) |
| 132 | local uri = request_uri(reqt, respt, tmp) | ||
| 133 | local headers = tmp.headers | ||
| 134 | local step = reqt.step or ltn12.pump.step | ||
| 135 | -- send request line | ||
| 136 | socket.try(tmp.sock:send((reqt.method or "GET") | ||
| 137 | .. " " .. uri .. " HTTP/1.1\r\n")) | ||
| 138 | if reqt.source and not headers["content-length"] then | ||
| 139 | headers["transfer-encoding"] = "chunked" | ||
| 140 | end | ||
| 141 | send_headers(tmp.sock, headers) | ||
| 142 | -- send request message body, if any | ||
| 143 | if not reqt.source then return end | ||
| 144 | if headers["content-length"] then | ||
| 145 | socket.try(ltn12.pump.all(reqt.source, | ||
| 146 | socket.sink(tmp.sock), step)) | ||
| 147 | else | ||
| 148 | socket.try(ltn12.pump.all(reqt.source, | ||
| 149 | socket.sink("http-chunked", tmp.sock), step)) | ||
| 150 | end | ||
| 151 | end | ||
| 152 | |||
| 153 | local function open(reqt, respt, tmp) | ||
| 154 | local proxy = reqt.proxy or PROXY | ||
| 155 | local host, port | ||
| 156 | if proxy then | ||
| 157 | local pproxy = url.parse(proxy) | ||
| 158 | socket.try(pproxy.port and pproxy.host, "invalid proxy") | ||
| 159 | host, port = pproxy.host, pproxy.port | ||
| 160 | else | ||
| 161 | host, port = tmp.parsed.host, tmp.parsed.port | ||
| 162 | end | ||
| 163 | -- store results | ||
| 164 | tmp.sock = socket.try(socket.tcp()) | ||
| 165 | socket.try(tmp.sock:settimeout(reqt.timeout or TIMEOUT)) | ||
| 166 | socket.try(tmp.sock:connect(host, port)) | ||
| 167 | end | ||
| 168 | |||
| 169 | local function adjust_headers(reqt, respt, tmp) | ||
| 170 | local lower = {} | 136 | local lower = {} |
| 171 | -- override with user values | 137 | -- override with user values |
| 172 | for i,v in (reqt.headers or lower) do | 138 | for i,v in (headers or lower) do |
| 173 | lower[string.lower(i)] = v | 139 | lower[string.lower(i)] = v |
| 174 | end | 140 | end |
| 175 | lower["user-agent"] = lower["user-agent"] or USERAGENT | 141 | lower["user-agent"] = lower["user-agent"] or USERAGENT |
| 176 | -- these cannot be overridden | 142 | -- these cannot be overridden |
| 177 | lower["host"] = tmp.parsed.host | 143 | lower["host"] = host |
| 178 | lower["connection"] = "close" | 144 | return lower |
| 179 | -- store results | ||
| 180 | tmp.headers = lower | ||
| 181 | end | 145 | end |
| 182 | 146 | ||
| 183 | local function parse_url(reqt, respt, tmp) | 147 | local function adjustrequest(reqt) |
| 184 | -- parse url with default fields | 148 | -- parse url with default fields |
| 185 | local parsed = url.parse(reqt.url, { | 149 | local parsed = url.parse(reqt.url, { |
| 186 | host = "", | 150 | host = "", |
| 187 | port = PORT, | 151 | port = PORT, |
| 188 | path ="/", | 152 | path ="/", |
| 189 | scheme = "http" | 153 | scheme = "http" |
| 190 | }) | 154 | }) |
| 191 | -- scheme has to be http | 155 | -- explicit info in reqt overrides that given by the URL |
| 192 | socket.try(parsed.scheme == "http", | 156 | for i,v in reqt do parsed[i] = v end |
| 193 | string.format("unknown scheme '%s'", parsed.scheme)) | 157 | -- compute uri if user hasn't overridden |
| 194 | -- explicit authentication info overrides that given by the URL | 158 | parsed.uri = parsed.uri or uri(parsed) |
| 195 | parsed.user = reqt.user or parsed.user | 159 | -- adjust headers in request |
| 196 | parsed.password = reqt.password or parsed.password | 160 | parsed.headers = adjustheaders(parsed.headers, parsed.host) |
| 197 | -- store results | 161 | return parsed |
| 198 | tmp.parsed = parsed | ||
| 199 | end | 162 | end |
| 200 | 163 | ||
| 201 | -- forward declaration | 164 | local function shouldredirect(reqt, respt) |
| 202 | local request_p | 165 | return (reqt.redirect ~= false) and |
| 166 | (respt.code == 301 or respt.code == 302) and | ||
| 167 | (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") | ||
| 168 | and (not reqt.nredirects or reqt.nredirects < 5) | ||
| 169 | end | ||
| 203 | 170 | ||
| 204 | local function should_authorize(reqt, respt, tmp) | 171 | local function shouldauthorize(reqt, respt) |
| 205 | -- if there has been an authorization attempt, it must have failed | 172 | -- if there has been an authorization attempt, it must have failed |
| 206 | if reqt.headers and reqt.headers["authorization"] then return nil end | 173 | if reqt.headers and reqt.headers["authorization"] then return nil end |
| 207 | -- if last attempt didn't fail due to lack of authentication, | 174 | -- if last attempt didn't fail due to lack of authentication, |
| 208 | -- or we don't have authorization information, we can't retry | 175 | -- or we don't have authorization information, we can't retry |
| 209 | return respt.code == 401 and tmp.parsed.user and tmp.parsed.password | 176 | return respt.code == 401 and reqt.user and reqt.password |
| 210 | end | 177 | end |
| 211 | 178 | ||
| 212 | local function clone(headers) | 179 | local function shouldreceivebody(reqt, respt) |
| 213 | if not headers then return nil end | 180 | if reqt.method == "HEAD" then return nil end |
| 214 | local copy = {} | 181 | local code = respt.code |
| 215 | for i,v in pairs(headers) do | 182 | if code == 204 or code == 304 then return nil end |
| 216 | copy[i] = v | 183 | if code >= 100 and code < 200 then return nil end |
| 217 | end | 184 | return 1 |
| 218 | return copy | ||
| 219 | end | 185 | end |
| 220 | 186 | ||
| 221 | local function authorize(reqt, respt, tmp) | 187 | local requestp, authorizep, redirectp |
| 222 | local headers = clone(reqt.headers) or {} | 188 | |
| 223 | headers["authorization"] = "Basic " .. | 189 | function requestp(reqt) |
| 224 | (mime.b64(tmp.parsed.user .. ":" .. tmp.parsed.password)) | 190 | local reqt = adjustrequest(reqt) |
| 225 | local autht = { | 191 | local respt = {} |
| 226 | method = reqt.method, | 192 | local con = open(reqt.host, reqt.port) |
| 227 | url = reqt.url, | 193 | con:sendrequestline(reqt.method, reqt.uri) |
| 228 | source = reqt.source, | 194 | con:sendheaders(reqt.headers) |
| 229 | sink = reqt.sink, | 195 | con:sendbody(reqt.headers, reqt.source, reqt.step) |
| 230 | headers = headers, | 196 | respt.code, respt.status = con:receivestatusline() |
| 231 | timeout = reqt.timeout, | 197 | respt.headers = con:receiveheaders() |
| 232 | proxy = reqt.proxy, | 198 | if shouldredirect(reqt, respt) then |
| 233 | } | 199 | con:close() |
| 234 | request_p(autht, respt, tmp) | 200 | return redirectp(reqt, respt) |
| 201 | elseif shouldauthorize(reqt, respt) then | ||
| 202 | con:close() | ||
| 203 | return authorizep(reqt, respt) | ||
| 204 | elseif shouldreceivebody(reqt, respt) then | ||
| 205 | con:receivebody(respt.headers, reqt.sink, reqt.step) | ||
| 206 | end | ||
| 207 | con:close() | ||
| 208 | return respt | ||
| 235 | end | 209 | end |
| 236 | 210 | ||
| 237 | local function should_redirect(reqt, respt, tmp) | 211 | function authorizep(reqt, respt) |
| 238 | return (reqt.redirect ~= false) and | 212 | local auth = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) |
| 239 | (respt.code == 301 or respt.code == 302) and | 213 | reqt.headers["authorization"] = auth |
| 240 | (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") | 214 | return requestp(reqt) |
| 241 | and (not tmp.nredirects or tmp.nredirects < 5) | ||
| 242 | end | 215 | end |
| 243 | 216 | ||
| 244 | local function redirect(reqt, respt, tmp) | 217 | function redirectp(reqt, respt) |
| 245 | tmp.nredirects = (tmp.nredirects or 0) + 1 | 218 | -- we create a new table to get rid of anything we don't |
| 219 | -- absolutely need, including authentication info | ||
| 246 | local redirt = { | 220 | local redirt = { |
| 247 | method = reqt.method, | 221 | method = reqt.method, |
| 248 | -- the RFC says the redirect URL has to be absolute, but some | 222 | -- the RFC says the redirect URL has to be absolute, but some |
| @@ -251,69 +225,38 @@ local function redirect(reqt, respt, tmp) | |||
| 251 | source = reqt.source, | 225 | source = reqt.source, |
| 252 | sink = reqt.sink, | 226 | sink = reqt.sink, |
| 253 | headers = reqt.headers, | 227 | headers = reqt.headers, |
| 254 | timeout = reqt.timeout, | 228 | proxy = reqt.proxy, |
| 255 | proxy = reqt.proxy | 229 | nredirects = (reqt.nredirects or 0) + 1 |
| 256 | } | 230 | } |
| 257 | request_p(redirt, respt, tmp) | 231 | respt = requestp(redirt) |
| 258 | -- we pass the location header as a clue we redirected | 232 | -- we pass the location header as a clue we redirected |
| 259 | if respt.headers then respt.headers.location = redirt.url end | 233 | if respt.headers then respt.headers.location = redirt.url end |
| 260 | end | ||
| 261 | |||
| 262 | local function skip_continue(reqt, respt, tmp) | ||
| 263 | if respt.code == 100 then | ||
| 264 | receive_status(reqt, respt, tmp) | ||
| 265 | end | ||
| 266 | end | ||
| 267 | |||
| 268 | -- execute a request or throw an exception | ||
| 269 | function request_p(reqt, respt, tmp) | ||
| 270 | parse_url(reqt, respt, tmp) | ||
| 271 | adjust_headers(reqt, respt, tmp) | ||
| 272 | open(reqt, respt, tmp) | ||
| 273 | send_request(reqt, respt, tmp) | ||
| 274 | receive_status(reqt, respt, tmp) | ||
| 275 | skip_continue(reqt, respt, tmp) | ||
| 276 | receive_headers(reqt, respt, tmp) | ||
| 277 | if should_redirect(reqt, respt, tmp) then | ||
| 278 | tmp.sock:close() | ||
| 279 | redirect(reqt, respt, tmp) | ||
| 280 | elseif should_authorize(reqt, respt, tmp) then | ||
| 281 | tmp.sock:close() | ||
| 282 | authorize(reqt, respt, tmp) | ||
| 283 | elseif should_receive_body(reqt, respt, tmp) then | ||
| 284 | receive_body(reqt, respt, tmp) | ||
| 285 | end | ||
| 286 | end | ||
| 287 | |||
| 288 | function request(reqt) | ||
| 289 | local respt, tmp = {}, {} | ||
| 290 | local s, e = pcall(request_p, reqt, respt, tmp) | ||
| 291 | if not s then respt.error = e end | ||
| 292 | if tmp.sock then tmp.sock:close() end | ||
| 293 | return respt | 234 | return respt |
| 294 | end | 235 | end |
| 295 | 236 | ||
| 296 | function get(u) | 237 | request = socket.protect(requestp) |
| 238 | |||
| 239 | get = socket.protect(function(u) | ||
| 297 | local t = {} | 240 | local t = {} |
| 298 | respt = request { | 241 | local respt = requestp { |
| 299 | url = u, | 242 | url = u, |
| 300 | sink = ltn12.sink.table(t) | 243 | sink = ltn12.sink.table(t) |
| 301 | } | 244 | } |
| 302 | return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, | 245 | return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, |
| 303 | respt.code, respt.error | 246 | respt.code |
| 304 | end | 247 | end) |
| 305 | 248 | ||
| 306 | function post(u, body) | 249 | post = socket.protect(function(u, body) |
| 307 | local t = {} | 250 | local t = {} |
| 308 | respt = request { | 251 | local respt = requestp { |
| 309 | url = u, | 252 | url = u, |
| 310 | method = "POST", | 253 | method = "POST", |
| 311 | source = ltn12.source.string(body), | 254 | source = ltn12.source.string(body), |
| 312 | sink = ltn12.sink.table(t), | 255 | sink = ltn12.sink.table(t), |
| 313 | headers = { ["content-length"] = string.len(body) } | 256 | headers = { ["content-length"] = string.len(body) } |
| 314 | } | 257 | } |
| 315 | return (table.getn(t) > 0 or nil) and table.concat(t), | 258 | return (table.getn(t) > 0 or nil) and table.concat(t), |
| 316 | respt.headers, respt.code, respt.error | 259 | respt.headers, respt.code |
| 317 | end | 260 | end) |
| 318 | 261 | ||
| 319 | return http | 262 | return http |
