diff options
| author | Diego Nehab <diego@tecgraf.puc-rio.br> | 2004-06-16 21:56:23 +0000 |
|---|---|---|
| committer | Diego Nehab <diego@tecgraf.puc-rio.br> | 2004-06-16 21:56:23 +0000 |
| commit | 574708380f19b15bd19419bfd64ccbe422f2d924 (patch) | |
| tree | deb5030132ca370193d1d64bd57fbcce4c7ce8eb | |
| parent | ba2f0b8c6ba7fb3a26fa6d9676ee1aefe6d873cc (diff) | |
| download | luasocket-574708380f19b15bd19419bfd64ccbe422f2d924.tar.gz luasocket-574708380f19b15bd19419bfd64ccbe422f2d924.tar.bz2 luasocket-574708380f19b15bd19419bfd64ccbe422f2d924.zip | |
Simplified HTTP module.
| -rw-r--r-- | src/http.lua | 138 | ||||
| -rw-r--r-- | src/url.lua | 2 | ||||
| -rw-r--r-- | test/httptest.lua | 28 |
3 files changed, 81 insertions, 87 deletions
diff --git a/src/http.lua b/src/http.lua index 8f3fdb9..e0c4c27 100644 --- a/src/http.lua +++ b/src/http.lua | |||
| @@ -143,117 +143,111 @@ local function adjustheaders(headers, host) | |||
| 143 | return lower | 143 | return lower |
| 144 | end | 144 | end |
| 145 | 145 | ||
| 146 | local default = { | ||
| 147 | host = "", | ||
| 148 | port = PORT, | ||
| 149 | path ="/", | ||
| 150 | scheme = "http" | ||
| 151 | } | ||
| 152 | |||
| 146 | local function adjustrequest(reqt) | 153 | local function adjustrequest(reqt) |
| 147 | -- parse url with default fields | 154 | -- parse url if provided |
| 148 | local parsed = url.parse(reqt.url or "", { | 155 | if reqt.url then |
| 149 | host = "", | 156 | local parsed = url.parse(reqt.url, default) |
| 150 | port = PORT, | 157 | -- explicit components override url |
| 151 | path ="/", | 158 | for i,v in parsed do reqt[i] = reqt[i] or v end |
| 152 | scheme = "http" | 159 | end |
| 153 | }) | 160 | socket.try(reqt.host, "invalid host '" .. tostring(reqt.host) .. "'") |
| 154 | -- explicit info in reqt overrides that given by the URL | 161 | socket.try(reqt.path, "invalid path '" .. tostring(reqt.path) .. "'") |
| 155 | for i,v in reqt do parsed[i] = v end | ||
| 156 | -- compute uri if user hasn't overriden | 162 | -- compute uri if user hasn't overriden |
| 157 | parsed.uri = parsed.uri or uri(parsed) | 163 | reqt.uri = reqt.uri or uri(reqt) |
| 158 | -- adjust headers in request | 164 | -- adjust headers in request |
| 159 | parsed.headers = adjustheaders(parsed.headers, parsed.host) | 165 | reqt.headers = adjustheaders(reqt.headers, reqt.host) |
| 160 | return parsed | 166 | return reqt |
| 161 | end | 167 | end |
| 162 | 168 | ||
| 163 | local function shouldredirect(reqt, respt) | 169 | local function shouldredirect(reqt, code) |
| 164 | return (reqt.redirect ~= false) and | 170 | return (reqt.redirect ~= false) and |
| 165 | (respt.code == 301 or respt.code == 302) and | 171 | (code == 301 or code == 302) and |
| 166 | (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") | 172 | (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") |
| 167 | and (not reqt.nredirects or reqt.nredirects < 5) | 173 | and (not reqt.nredirects or reqt.nredirects < 5) |
| 168 | end | 174 | end |
| 169 | 175 | ||
| 170 | local function shouldauthorize(reqt, respt) | 176 | local function shouldauthorize(reqt, code) |
| 171 | -- if there has been an authorization attempt, it must have failed | 177 | -- if there has been an authorization attempt, it must have failed |
| 172 | if reqt.headers and reqt.headers["authorization"] then return nil end | 178 | if reqt.headers and reqt.headers["authorization"] then return nil end |
| 173 | -- if last attempt didn't fail due to lack of authentication, | 179 | -- if last attempt didn't fail due to lack of authentication, |
| 174 | -- or we don't have authorization information, we can't retry | 180 | -- or we don't have authorization information, we can't retry |
| 175 | return respt.code == 401 and reqt.user and reqt.password | 181 | return code == 401 and reqt.user and reqt.password |
| 176 | end | 182 | end |
| 177 | 183 | ||
| 178 | local function shouldreceivebody(reqt, respt) | 184 | local function shouldreceivebody(reqt, code) |
| 179 | if reqt.method == "HEAD" then return nil end | 185 | if reqt.method == "HEAD" then return nil end |
| 180 | local code = respt.code | ||
| 181 | if code == 204 or code == 304 then return nil end | 186 | if code == 204 or code == 304 then return nil end |
| 182 | if code >= 100 and code < 200 then return nil end | 187 | if code >= 100 and code < 200 then return nil end |
| 183 | return 1 | 188 | return 1 |
| 184 | end | 189 | end |
| 185 | 190 | ||
| 186 | local requestp, authorizep, redirectp | 191 | -- forward declarations |
| 187 | 192 | local trequest, tauthorize, tredirect | |
| 188 | function requestp(reqt) | ||
| 189 | local reqt = adjustrequest(reqt) | ||
| 190 | local respt = {} | ||
| 191 | local con = open(reqt.host, reqt.port) | ||
| 192 | con:sendrequestline(reqt.method, reqt.uri) | ||
| 193 | con:sendheaders(reqt.headers) | ||
| 194 | con:sendbody(reqt.headers, reqt.source, reqt.step) | ||
| 195 | respt.code, respt.status = con:receivestatusline() | ||
| 196 | respt.headers = con:receiveheaders() | ||
| 197 | if shouldredirect(reqt, respt) then | ||
| 198 | con:close() | ||
| 199 | return redirectp(reqt, respt) | ||
| 200 | elseif shouldauthorize(reqt, respt) then | ||
| 201 | con:close() | ||
| 202 | return authorizep(reqt, respt) | ||
| 203 | elseif shouldreceivebody(reqt, respt) then | ||
| 204 | con:receivebody(respt.headers, reqt.sink, reqt.step) | ||
| 205 | end | ||
| 206 | con:close() | ||
| 207 | return respt | ||
| 208 | end | ||
| 209 | 193 | ||
| 210 | function authorizep(reqt, respt) | 194 | function tauthorize(reqt) |
| 211 | local auth = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) | 195 | local auth = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) |
| 212 | reqt.headers["authorization"] = auth | 196 | reqt.headers["authorization"] = auth |
| 213 | return requestp(reqt) | 197 | return trequest(reqt) |
| 214 | end | 198 | end |
| 215 | 199 | ||
| 216 | function redirectp(reqt, respt) | 200 | function tredirect(reqt, headers) |
| 217 | -- we create a new table to get rid of anything we don't | 201 | -- the RFC says the redirect URL has to be absolute, but some |
| 218 | -- absolutely need, including authentication info | 202 | -- servers do not respect that |
| 219 | local redirt = { | 203 | return trequest { |
| 220 | method = reqt.method, | 204 | url = url.absolute(reqt, headers["location"]), |
| 221 | -- the RFC says the redirect URL has to be absolute, but some | ||
| 222 | -- servers do not respect that | ||
| 223 | url = url.absolute(reqt.url, respt.headers["location"]), | ||
| 224 | source = reqt.source, | 205 | source = reqt.source, |
| 225 | sink = reqt.sink, | 206 | sink = reqt.sink, |
| 226 | headers = reqt.headers, | 207 | headers = reqt.headers, |
| 227 | proxy = reqt.proxy, | 208 | proxy = reqt.proxy, |
| 228 | nredirects = (reqt.nredirects or 0) + 1 | 209 | nredirects = (reqt.nredirects or 0) + 1 |
| 229 | } | 210 | } |
| 230 | respt = requestp(redirt) | ||
| 231 | -- we pass the location header as a clue we redirected | ||
| 232 | if respt.headers then respt.headers.location = redirt.url end | ||
| 233 | return respt | ||
| 234 | end | 211 | end |
| 235 | 212 | ||
| 236 | request = socket.protect(requestp) | 213 | function trequest(reqt) |
| 214 | reqt = adjustrequest(reqt) | ||
| 215 | local con = open(reqt.host, reqt.port) | ||
| 216 | con:sendrequestline(reqt.method, reqt.uri) | ||
| 217 | con:sendheaders(reqt.headers) | ||
| 218 | con:sendbody(reqt.headers, reqt.source, reqt.step) | ||
| 219 | local code, headers, status | ||
| 220 | code, status = con:receivestatusline() | ||
| 221 | headers = con:receiveheaders() | ||
| 222 | if shouldredirect(reqt, code) then | ||
| 223 | con:close() | ||
| 224 | return tredirect(reqt, headers) | ||
| 225 | elseif shouldauthorize(reqt, code) then | ||
| 226 | con:close() | ||
| 227 | return tauthorize(reqt) | ||
| 228 | elseif shouldreceivebody(reqt, code) then | ||
| 229 | con:receivebody(headers, reqt.sink, reqt.step) | ||
| 230 | end | ||
| 231 | con:close() | ||
| 232 | return 1, code, headers, status | ||
| 233 | end | ||
| 237 | 234 | ||
| 238 | get = socket.protect(function(u) | 235 | local function srequest(u, body) |
| 239 | local t = {} | 236 | local t = {} |
| 240 | local respt = requestp { | 237 | local reqt = { |
| 241 | url = u, | 238 | url = u, |
| 242 | sink = ltn12.sink.table(t) | 239 | sink = ltn12.sink.table(t) |
| 243 | } | 240 | } |
| 244 | return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, | 241 | if body then |
| 245 | respt.code | 242 | reqt.source = ltn12.source.string(body) |
| 246 | end) | 243 | reqt.headers = { ["content-length"] = string.len(body) } |
| 244 | reqt.method = "POST" | ||
| 245 | end | ||
| 246 | local code, headers, status = socket.skip(1, trequest(reqt)) | ||
| 247 | return table.concat(t), code, headers, status | ||
| 248 | end | ||
| 247 | 249 | ||
| 248 | post = socket.protect(function(u, body) | 250 | request = socket.protect(function(reqt, body) |
| 249 | local t = {} | 251 | if type(reqt) == "string" then return srequest(reqt, body) |
| 250 | local respt = requestp { | 252 | else return trequest(reqt) end |
| 251 | url = u, | ||
| 252 | method = "POST", | ||
| 253 | source = ltn12.source.string(body), | ||
| 254 | sink = ltn12.sink.table(t), | ||
| 255 | headers = { ["content-length"] = string.len(body) } | ||
| 256 | } | ||
| 257 | return (table.getn(t) > 0 or nil) and table.concat(t), | ||
| 258 | respt.headers, respt.code | ||
| 259 | end) | 253 | end) |
diff --git a/src/url.lua b/src/url.lua index 960a248..ec26e62 100644 --- a/src/url.lua +++ b/src/url.lua | |||
| @@ -190,7 +190,7 @@ end | |||
| 190 | -- corresponding absolute url | 190 | -- corresponding absolute url |
| 191 | ----------------------------------------------------------------------------- | 191 | ----------------------------------------------------------------------------- |
| 192 | function absolute(base_url, relative_url) | 192 | function absolute(base_url, relative_url) |
| 193 | local base = parse(base_url) | 193 | local base = type(base_url) == "table" and base_url or parse(base_url) |
| 194 | local relative = parse(relative_url) | 194 | local relative = parse(relative_url) |
| 195 | if not base then return relative_url | 195 | if not base then return relative_url |
| 196 | elseif not relative then return base_url | 196 | elseif not relative then return base_url |
diff --git a/test/httptest.lua b/test/httptest.lua index 61dc60a..a171dd9 100644 --- a/test/httptest.lua +++ b/test/httptest.lua | |||
| @@ -55,12 +55,12 @@ end | |||
| 55 | 55 | ||
| 56 | local check_request = function(request, expect, ignore) | 56 | local check_request = function(request, expect, ignore) |
| 57 | local t | 57 | local t |
| 58 | if not request.sink then | 58 | if not request.sink then request.sink, t = ltn12.sink.table() end |
| 59 | request.sink, t = ltn12.sink.table(t) | ||
| 60 | end | ||
| 61 | request.source = request.source or | 59 | request.source = request.source or |
| 62 | (request.body and ltn12.source.string(request.body)) | 60 | (request.body and ltn12.source.string(request.body)) |
| 63 | local response = http.request(request) | 61 | local response = {} |
| 62 | response.code, response.headers, response.status = | ||
| 63 | socket.skip(1, http.request(request)) | ||
| 64 | if t and table.getn(t) > 0 then response.body = table.concat(t) end | 64 | if t and table.getn(t) > 0 then response.body = table.concat(t) end |
| 65 | check_result(response, expect, ignore) | 65 | check_result(response, expect, ignore) |
| 66 | end | 66 | end |
| @@ -68,8 +68,8 @@ end | |||
| 68 | ------------------------------------------------------------------------ | 68 | ------------------------------------------------------------------------ |
| 69 | io.write("testing request uri correctness: ") | 69 | io.write("testing request uri correctness: ") |
| 70 | local forth = cgiprefix .. "/request-uri?" .. "this+is+the+query+string" | 70 | local forth = cgiprefix .. "/request-uri?" .. "this+is+the+query+string" |
| 71 | local back, h, c, e = http.get("http://" .. host .. forth) | 71 | local back, c, h = http.request("http://" .. host .. forth) |
| 72 | if not back then fail(e) end | 72 | if not back then fail(c) end |
| 73 | back = url.parse(back) | 73 | back = url.parse(back) |
| 74 | if similar(back.query, "this+is+the+query+string") then print("ok") | 74 | if similar(back.query, "this+is+the+query+string") then print("ok") |
| 75 | else fail(back.query) end | 75 | else fail(back.query) end |
| @@ -77,7 +77,7 @@ else fail(back.query) end | |||
| 77 | ------------------------------------------------------------------------ | 77 | ------------------------------------------------------------------------ |
| 78 | io.write("testing query string correctness: ") | 78 | io.write("testing query string correctness: ") |
| 79 | forth = "this+is+the+query+string" | 79 | forth = "this+is+the+query+string" |
| 80 | back = http.get("http://" .. host .. cgiprefix .. | 80 | back = http.request("http://" .. host .. cgiprefix .. |
| 81 | "/query-string?" .. forth) | 81 | "/query-string?" .. forth) |
| 82 | if similar(back, forth) then print("ok") | 82 | if similar(back, forth) then print("ok") |
| 83 | else fail("failed!") end | 83 | else fail("failed!") end |
| @@ -153,7 +153,7 @@ check_request(request, expect, ignore) | |||
| 153 | 153 | ||
| 154 | ------------------------------------------------------------------------ | 154 | ------------------------------------------------------------------------ |
| 155 | io.write("testing simple post function: ") | 155 | io.write("testing simple post function: ") |
| 156 | back = http.post("http://" .. host .. cgiprefix .. "/cat", index) | 156 | back = http.request("http://" .. host .. cgiprefix .. "/cat", index) |
| 157 | assert(back == index) | 157 | assert(back == index) |
| 158 | 158 | ||
| 159 | ------------------------------------------------------------------------ | 159 | ------------------------------------------------------------------------ |
| @@ -378,19 +378,19 @@ check_request(request, expect, ignore) | |||
| 378 | 378 | ||
| 379 | ------------------------------------------------------------------------ | 379 | ------------------------------------------------------------------------ |
| 380 | local body | 380 | local body |
| 381 | io.write("testing simple get function: ") | 381 | io.write("testing simple request function: ") |
| 382 | body = http.get("http://" .. host .. prefix .. "/index.html") | 382 | body = http.request("http://" .. host .. prefix .. "/index.html") |
| 383 | assert(body == index) | 383 | assert(body == index) |
| 384 | print("ok") | 384 | print("ok") |
| 385 | 385 | ||
| 386 | ------------------------------------------------------------------------ | 386 | ------------------------------------------------------------------------ |
| 387 | io.write("testing HEAD method: ") | 387 | io.write("testing HEAD method: ") |
| 388 | http.TIMEOUT = 1 | 388 | http.TIMEOUT = 1 |
| 389 | response = http.request { | 389 | local r, c, h = http.request { |
| 390 | method = "HEAD", | 390 | method = "HEAD", |
| 391 | url = "http://www.cs.princeton.edu/~diego/" | 391 | url = "http://www.cs.princeton.edu/~diego/" |
| 392 | } | 392 | } |
| 393 | assert(response and response.headers) | 393 | assert(r and h and c == 200) |
| 394 | print("ok") | 394 | print("ok") |
| 395 | 395 | ||
| 396 | ------------------------------------------------------------------------ | 396 | ------------------------------------------------------------------------ |
| @@ -398,7 +398,7 @@ io.write("testing host not found: ") | |||
| 398 | local c, e = socket.connect("wronghost", 80) | 398 | local c, e = socket.connect("wronghost", 80) |
| 399 | local r, re = http.request{url = "http://wronghost/does/not/exist"} | 399 | local r, re = http.request{url = "http://wronghost/does/not/exist"} |
| 400 | assert(r == nil and e == re) | 400 | assert(r == nil and e == re) |
| 401 | r, re = http.get("http://wronghost/does/not/exist") | 401 | r, re = http.request("http://wronghost/does/not/exist") |
| 402 | assert(r == nil and e == re) | 402 | assert(r == nil and e == re) |
| 403 | print("ok") | 403 | print("ok") |
| 404 | 404 | ||
| @@ -407,7 +407,7 @@ io.write("testing invalid url: ") | |||
| 407 | local c, e = socket.connect("", 80) | 407 | local c, e = socket.connect("", 80) |
| 408 | local r, re = http.request{url = host .. prefix} | 408 | local r, re = http.request{url = host .. prefix} |
| 409 | assert(r == nil and e == re) | 409 | assert(r == nil and e == re) |
| 410 | r, re = http.get(host .. prefix) | 410 | r, re = http.request(host .. prefix) |
| 411 | assert(r == nil and e == re) | 411 | assert(r == nil and e == re) |
| 412 | print("ok") | 412 | print("ok") |
| 413 | 413 | ||
