aboutsummaryrefslogtreecommitdiff
path: root/src/http.lua
blob: b265650333d2b39f930c0c5a5a54d5f68f86468a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
-----------------------------------------------------------------------------
-- HTTP/1.1 client support for the Lua language.
-- LuaSocket toolkit.
-- Author: Diego Nehab
-- RCS ID: $Id$
-----------------------------------------------------------------------------

-----------------------------------------------------------------------------
-- Declare module and import dependencies
-------------------------------------------------------------------------------
local socket = require("socket")
local url = require("socket.url")
local ltn12 = require("ltn12")
local mime = require("mime")

-- Declare the "socket.http" module via the legacy module() call.
-- NOTE(review): the non-local names assigned below (TIMEOUT, PORT, open,
-- request, ...) are presumably exported as fields of this module -- confirm
-- against the module() implementation in use.
module("socket.http")

-----------------------------------------------------------------------------
-- Program constants
-----------------------------------------------------------------------------
-- connection timeout in seconds; open() applies it to every new socket
TIMEOUT = 60 
-- default port for document retrieval; used by open() when no port is given
PORT = 80
-- user agent field sent in request unless the caller supplies its own
USERAGENT = socket.VERSION

-----------------------------------------------------------------------------
-- Low level HTTP API
-----------------------------------------------------------------------------
-- metatable attached by open() to every connection object; the low level
-- request/response methods defined below are stored in its __index table
local metat = { __index = {} }

-- Creates a TCP socket, connects it to an HTTP server and wraps it in a
-- connection object that carries the low level methods.
--   host: name or address to connect to
--   port: port to connect to (PORT is used when omitted)
-- Returns the connection object. Failures are raised through socket.try and
-- through the object's own try function.
function open(host, port)
    local sock = socket.try(socket.tcp())
    local conn = setmetatable({ c = sock }, metat)
    -- errors reported through conn.try close the connection before raising
    conn.try = socket.newtry(function() conn:close() end)
    conn.try(sock:settimeout(TIMEOUT))
    conn.try(sock:connect(host, port or PORT))
    return conn
end

-- Sends the HTTP/1.1 request line.
--   method: request method (defaults to "GET")
--   uri: request target written after the method
-- Returns whatever the socket send returns; raises through self.try on error.
function metat.__index:sendrequestline(method, uri)
    local verb = method or "GET"
    local line = string.format("%s %s HTTP/1.1\r\n", verb, uri)
    return self.try(self.c:send(line))
end

-- Sends every name/value pair of the headers table as a "name: value" line,
-- followed by the empty line that terminates the header section.
--   headers: table mapping header names to values
-- Returns 1; raises through self.try if any send fails.
function metat.__index:sendheaders(headers)
    local conn = self.c
    for name, value in pairs(headers) do
        local field = name .. ": " .. value .. "\r\n"
        self.try(conn:send(field))
    end
    -- blank line marks the end of the request headers
    self.try(conn:send("\r\n"))
    return 1
end

-- Pumps the request body from an LTN12 source into the connection.
--   headers: request headers; a "content-length" entry selects the sink mode
--   source: LTN12 source with the body (an empty source when omitted)
--   step: LTN12 pump step function (ltn12.pump.step when omitted)
-- Returns the result of the pump; raises through self.try on failure.
function metat.__index:sendbody(headers, source, step)
    if not source then source = ltn12.source.empty() end
    if not step then step = ltn12.pump.step end
    local mode
    if headers["content-length"] then
        -- size was announced up front, so the data goes out as is
        mode = "keep-open"
    else
        -- size unknown in advance: send chunked and hope for the best
        mode = "http-chunked"
    end
    local sink = socket.sink(mode, self.c)
    return self.try(ltn12.pump.all(source, sink, step))
end

-- Reads the status line of the response and extracts its three digit code.
-- Returns the numeric code followed by the raw status line. If no code can
-- be extracted, the status line itself is raised as the error message.
function metat.__index:receivestatusline()
    local statusline = self.try(self.c:receive())
    -- third result of string.find is the captured status code
    local _, _, digits = string.find(statusline, "HTTP/%d*%.%d* (%d%d%d)")
    return self.try(tonumber(digits), statusline)
end

-- Reads response header lines until the blank line that ends them.
-- Returns a table keyed by lower-cased field name. Folded continuation
-- lines are appended to the value they continue, and a field that occurs
-- more than once has its values joined with ", ".
-- Raises through self.try on receive errors or on a line without a colon.
function metat.__index:receiveheaders()
    local received = {}
    local current = self.try(self.c:receive())
    -- an empty line terminates the header section
    while current ~= "" do
        -- split "name: value"
        local key, val = socket.skip(2, string.find(current, "^(.-):%s*(.*)"))
        self.try(key and val, "malformed reponse headers")
        key = string.lower(key)
        -- look ahead: lines starting with whitespace continue this value
        current = self.try(self.c:receive())
        while string.find(current, "^%s") do
            val = val .. current
            current = self.try(self.c:receive())
        end
        -- merge with an earlier occurrence of the same field, if any
        local previous = received[key]
        if previous then
            received[key] = previous .. ", " .. val
        else
            received[key] = val
        end
    end
    return received
end

-- Pumps the response body from the connection into an LTN12 sink.
--   headers: response headers, used to choose how the body is delimited
--   sink: LTN12 sink for the body (a null sink when omitted)
--   step: LTN12 pump step function (ltn12.pump.step when omitted)
-- Returns the result of the pump; raises through self.try on failure.
function metat.__index:receivebody(headers, sink, step)
    if not sink then sink = ltn12.sink.null() end
    if not step then step = ltn12.pump.step end
    local size = tonumber(headers["content-length"])
    local encoding = headers["transfer-encoding"]
    local mode
    if encoding and encoding ~= "identity" then
        -- any non-identity transfer encoding is treated as chunked
        mode = "http-chunked"
    elseif size then
        mode = "by-length"
    else
        -- no framing information: read until the connection is closed
        mode = "default"
    end
    local source = socket.source(mode, self.c, size)
    return self.try(ltn12.pump.all(source, sink, step))
end

-- Closes the underlying socket and returns whatever its close() returns.
function metat.__index:close()
    local sock = self.c
    return sock:close()
end

-----------------------------------------------------------------------------
-- High level HTTP API
-----------------------------------------------------------------------------
-- Builds the request target for the request line.
-- When a proxy is configured (reqt.proxy or PROXY) the URL is built from
-- the whole request table; otherwise only path, params, query and fragment
-- are used. Raises through socket.try when there is no proxy and no path.
local function adjusturi(reqt)
    if reqt.proxy or PROXY then
        -- proxies need the full url
        return url.build(reqt)
    end
    -- talking to the server directly: just the part after the authority
    return url.build({
        path = socket.try(reqt.path, "invalid path 'nil'"),
        params = reqt.params,
        query = reqt.query,
        fragment = reqt.fragment
    })
end

-- Chooses the host and port the socket should actually connect to.
-- Returns the proxy's host and port (port 3128 when the proxy URL has none)
-- if reqt.proxy or PROXY is set; otherwise returns reqt.host and reqt.port.
local function adjustproxy(reqt)
    local proxy = reqt.proxy or PROXY
    if not proxy then
        return reqt.host, reqt.port
    end
    local parsed = url.parse(proxy)
    return parsed.host, parsed.port or 3128
end

-- Builds the header table that is sent with a request.
--   headers: user supplied headers (may be nil); names are lower-cased
--   host: value stored in the "host" field
-- Returns a new table; the table passed in is not modified. A missing
-- "user-agent" is filled in with USERAGENT, and "host" is always replaced.
local function adjustheaders(headers, host)
    local lower = {}
    -- copy user values under lower-cased names; pairs() is required here
    -- because iterating a bare table in a generic for only works in Lua 5.0
    if headers then
        for name, value in pairs(headers) do
            lower[string.lower(name)] = value
        end
    end
    lower["user-agent"] = lower["user-agent"] or USERAGENT
    -- this one cannot be overridden by the user
    lower["host"] = host
    return lower
end

-- fallback URL components handed to url.parse by adjustrequest; note that
-- port takes the value PORT has when this file is loaded
local default = {
    host = "",
    port = PORT,
    path ="/",
    scheme = "http"
}

-- Normalizes a user request table into the form used by trequest.
--   reqt: request table; may carry a url plus explicit components
-- Returns a new table with the parsed url components, the explicit fields
-- of reqt copied over them, and uri, headers, host and port filled in.
-- Raises through socket.try when no usable host is available.
local function adjustrequest(reqt)
    -- parse url if provided
    local nreqt = reqt.url and url.parse(reqt.url, default) or {}
    -- explicit components override url (pairs() is needed: iterating a bare
    -- table in a generic for only works in Lua 5.0)
    for name, value in pairs(reqt) do nreqt[name] = value end
    -- the default host is "", which is truthy, so treat it as missing too
    local host = nreqt.host
    if host == "" then host = nil end
    socket.try(host, "invalid host '" .. tostring(nreqt.host) .. "'")
    -- compute uri if user hasn't overridden it
    nreqt.uri = reqt.uri or adjusturi(nreqt)
    -- adjust headers while nreqt.host still names the origin server, so the
    -- "host" field does not end up carrying the proxy's host
    nreqt.headers = adjustheaders(nreqt.headers, nreqt.host)
    -- adjust host and port if there is a proxy
    nreqt.host, nreqt.port = adjustproxy(nreqt)
    return nreqt
end

-- Decides whether a response should be followed as a redirect.
-- Returns true only when redirects are not disabled (reqt.redirect is not
-- false), the code is 301 or 302, the method is absent, "GET" or "HEAD",
-- and fewer than 5 redirects have been followed; returns false otherwise.
local function shouldredirect(reqt, code)
    if reqt.redirect == false then return false end
    if code ~= 301 and code ~= 302 then return false end
    local method = reqt.method
    if method and method ~= "GET" and method ~= "HEAD" then return false end
    local hops = reqt.nredirects
    return not hops or hops < 5
end

-- Decides whether the request should be retried with credentials.
-- Returns nil when an "authorization" header was already sent (that attempt
-- must have failed), false when the code is not 401, and otherwise the
-- value of reqt.user and reqt.password (truthy only if both are present).
local function shouldauthorize(reqt, code)
    local sent = reqt.headers
    if sent and sent["authorization"] then return nil end
    -- only a 401 means the failure was due to lack of authentication
    if code ~= 401 then return false end
    -- without both user and password there is nothing to retry with
    return reqt.user and reqt.password
end

-- Decides whether a response body should be read.
-- Returns nil for HEAD requests and for status codes 204, 304 and 1xx;
-- returns 1 in every other case.
local function shouldreceivebody(reqt, code)
    local bodyless = reqt.method == "HEAD"
        or code == 204 or code == 304
        or (code >= 100 and code < 200)
    if bodyless then return nil end
    return 1
end

-- forward declarations: trequest calls tauthorize and tredirect, and both
-- of them call trequest, so the locals must exist before the bodies below
local trequest, tauthorize, tredirect

-- Retries a request with a Basic "authorization" header built from
-- reqt.user and reqt.password. The header is stored in reqt.headers, so
-- reqt is modified in place. Returns whatever trequest returns.
function tauthorize(reqt)
    local credentials = reqt.user .. ":" .. reqt.password
    -- only the first result of mime.b64 is used
    local encoded = mime.b64(credentials)
    reqt.headers["authorization"] = "Basic " .. encoded
    return trequest(reqt)
end

-- Follows a redirect by issuing a new request to the "location" given in
-- the response headers.
--   reqt: the request that was answered with a redirect
--   headers: response headers of that request
-- Returns whatever trequest returns for the new request.
function tredirect(reqt, headers)
    return trequest {
        -- the RFC says the redirect URL has to be absolute, but some
        -- servers do not respect that
        url = url.absolute(reqt, headers["location"]),
        -- keep the method: shouldredirect lets HEAD through, and without
        -- this field the follow-up request would fall back to GET
        method = reqt.method,
        source = reqt.source,
        sink = reqt.sink,
        -- keep the caller's pump step for the follow-up transfer
        step = reqt.step,
        headers = reqt.headers,
        proxy = reqt.proxy,
        nredirects = (reqt.nredirects or 0) + 1
    }
end

-- Performs one HTTP request described by a request table, following
-- redirects and retrying with credentials when the helpers say so.
--   reqt: request table (url and/or explicit components, method, headers,
--         source, sink, step, proxy, user, password, redirect)
-- Returns 1, the status code, the response headers and the status line.
-- Errors are raised (through socket.try and the connection's try function).
function trequest(reqt)
    reqt = adjustrequest(reqt)
    local h = open(reqt.host, reqt.port)
    h:sendrequestline(reqt.method, reqt.uri)
    h:sendheaders(reqt.headers)
    -- only send a body when there is one. sendbody selects chunked mode
    -- whenever no content-length is set, so calling it for a body-less
    -- request would presumably write a chunk terminator after the headers
    -- (see the socket.sink "http-chunked" documentation)
    if reqt.source then
        h:sendbody(reqt.headers, reqt.source, reqt.step)
    end
    local code, headers, status
    code, status = h:receivestatusline()
    headers = h:receiveheaders()
    if shouldredirect(reqt, code) then 
        h:close()
        return tredirect(reqt, headers)
    elseif shouldauthorize(reqt, code) then 
        h:close()
        return tauthorize(reqt)
    elseif shouldreceivebody(reqt, code) then
        h:receivebody(headers, reqt.sink, reqt.step)
    end
    h:close()
    return 1, code, headers, status
end

-- Simple string interface: fetches a URL and collects the body in memory.
--   u: URL to request
--   body: optional request body; when given the request becomes a POST
--         with a matching "content-length" header
-- Returns the response body as a string, then the status code, the
-- response headers and the status line.
local function srequest(u, body)
    local chunks = {}
    local reqt = { url = u, sink = ltn12.sink.table(chunks) }
    if body then
        reqt.method = "POST"
        reqt.source = ltn12.source.string(body)
        reqt.headers = { ["content-length"] = string.len(body) }
    end
    -- first result of trequest is just the success marker; drop it
    local _, code, headers, status = trequest(reqt)
    return table.concat(chunks), code, headers, status
end

-- Entry point of the high level API, wrapped in socket.protect.
--   reqt: either a URL string (simple form) or a request table
--   body: optional request body, only used with the string form
-- The string form is handled by srequest, everything else by trequest.
request = socket.protect(function(reqt, body)
    if type(reqt) ~= "string" then
        return trequest(reqt)
    end
    return srequest(reqt, body)
end)