diff options
-rw-r--r-- | samples/forward.lua | 93 | ||||
-rw-r--r-- | src/http.lua | 28 | ||||
-rw-r--r-- | test/httptest.lua | 2 |
3 files changed, 63 insertions, 60 deletions
diff --git a/samples/forward.lua b/samples/forward.lua index de651b4..c3f0605 100644 --- a/samples/forward.lua +++ b/samples/forward.lua | |||
@@ -1,11 +1,5 @@ | |||
1 | -- load our favourite library | 1 | -- load our favourite library |
2 | local socket = require"socket" | 2 | local socket = require"socket" |
3 | -- timeout before an inactive thread is kicked | ||
4 | local TIMEOUT = 10 | ||
5 | -- local address to bind to | ||
6 | local ihost, iport = arg[1] or "localhost", arg[2] or 8080 | ||
7 | -- address to forward all data to | ||
8 | local ohost, oport = arg[3] or "localhost", arg[4] or 3128 | ||
9 | 3 | ||
10 | -- creates a new set data structure | 4 | -- creates a new set data structure |
11 | function newset() | 5 | function newset() |
@@ -32,12 +26,44 @@ function newset() | |||
32 | }}) | 26 | }}) |
33 | end | 27 | end |
34 | 28 | ||
29 | -- timeout before an inactive thread is kicked | ||
30 | local TIMEOUT = 10 | ||
31 | -- set of connections waiting to receive data | ||
35 | local receiving = newset() | 32 | local receiving = newset() |
33 | -- set of sockets waiting to send data | ||
36 | local sending = newset() | 34 | local sending = newset() |
35 | -- context for connections and servers | ||
37 | local context = {} | 36 | local context = {} |
38 | 37 | ||
39 | -- starts a non-blocking connect | 38 | -- initializes the forward server |
40 | function nconnect(host, port) | 39 | function init() |
40 | if table.getn(arg) < 1 then | ||
41 | print("Usage") | ||
42 | print(" lua forward.lua <iport:ohost:oport> ...") | ||
43 | os.exit(1) | ||
44 | end | ||
45 | -- for each tunnel, start a new server socket | ||
46 | for i, v in ipairs(arg) do | ||
47 | -- capture forwarding parameters | ||
48 | local iport, ohost, oport = | ||
49 | socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)")) | ||
50 | assert(iport, "invalid arguments") | ||
51 | -- create our server socket | ||
52 | local server = assert(socket.bind("*", iport)) | ||
53 | server:settimeout(0.1) -- we don't want to be killed by bad luck | ||
54 | -- make sure server is tested for readability | ||
55 | receiving:insert(server) | ||
56 | -- add server context | ||
57 | context[server] = { | ||
58 | thread = coroutine.create(accept), | ||
59 | ohost = ohost, | ||
60 | oport = oport | ||
61 | } | ||
62 | end | ||
63 | end | ||
64 | |||
65 | -- starts a connection in a non-blocking way | ||
66 | function nbkcon(host, port) | ||
41 | local peer, err = socket.tcp() | 67 | local peer, err = socket.tcp() |
42 | if not peer then return nil, err end | 68 | if not peer then return nil, err end |
43 | peer:settimeout(0) | 69 | peer:settimeout(0) |
@@ -52,7 +78,6 @@ end | |||
52 | 78 | ||
53 | -- gets rid of a client | 79 | -- gets rid of a client |
54 | function kick(who) | 80 | function kick(who) |
55 | if who == server then error("FUDEU") end | ||
56 | if context[who] then | 81 | if context[who] then |
57 | sending:remove(who) | 82 | sending:remove(who) |
58 | receiving:remove(who) | 83 | receiving:remove(who) |
@@ -63,7 +88,6 @@ end | |||
63 | 88 | ||
64 | -- decides what to do with a thread based on coroutine return | 89 | -- decides what to do with a thread based on coroutine return |
65 | function route(who, status, what) | 90 | function route(who, status, what) |
66 | print(who, status, what) | ||
67 | if status and what then | 91 | if status and what then |
68 | if what == "receiving" then receiving:insert(who) end | 92 | if what == "receiving" then receiving:insert(who) end |
69 | if what == "sending" then sending:insert(who) end | 93 | if what == "sending" then sending:insert(who) end |
@@ -73,12 +97,13 @@ end | |||
73 | -- loops accepting connections and creating new threads to deal with them | 97 | -- loops accepting connections and creating new threads to deal with them |
74 | function accept(server) | 98 | function accept(server) |
75 | while true do | 99 | while true do |
76 | print(server, "accepting a new client") | ||
77 | -- accept a new connection and start a new coroutine to deal with it | 100 | -- accept a new connection and start a new coroutine to deal with it |
78 | local client = server:accept() | 101 | local client = server:accept() |
79 | if client then | 102 | if client then |
80 | -- start a new connection, non-blockingly, to the forwarding address | 103 | -- start a new connection, non-blockingly, to the forwarding address |
81 | local peer = nconnect(ohost, oport) | 104 | local ohost = context[server].ohost |
105 | local oport = context[server].oport | ||
106 | local peer = nbkcon(ohost, oport) | ||
82 | if peer then | 107 | if peer then |
83 | context[client] = { | 108 | context[client] = { |
84 | last = socket.gettime(), | 109 | last = socket.gettime(), |
@@ -90,7 +115,7 @@ print(server, "accepting a new client") | |||
90 | sending:insert(peer) | 115 | sending:insert(peer) |
91 | context[peer] = { | 116 | context[peer] = { |
92 | peer = client, | 117 | peer = client, |
93 | thread = coroutine.create(check), | 118 | thread = coroutine.create(chkcon), |
94 | last = socket.gettime() | 119 | last = socket.gettime() |
95 | } | 120 | } |
96 | -- put both in non-blocking mode | 121 | -- put both in non-blocking mode |
@@ -109,14 +134,12 @@ end | |||
109 | -- forwards all data arriving to the appropriate peer | 134 | -- forwards all data arriving to the appropriate peer |
110 | function forward(who) | 135 | function forward(who) |
111 | while true do | 136 | while true do |
112 | print(who, "getting data") | ||
113 | -- try to read as much as possible | 137 | -- try to read as much as possible |
114 | local data, rec_err, partial = who:receive("*a") | 138 | local data, rec_err, partial = who:receive("*a") |
115 | -- if we had an error other than timeout, abort | 139 | -- if we had an error other than timeout, abort |
116 | if rec_err and rec_err ~= "timeout" then return error(rec_err) end | 140 | if rec_err and rec_err ~= "timeout" then return error(rec_err) end |
117 | -- if we got a timeout, we probably have partial results to send | 141 | -- if we got a timeout, we probably have partial results to send |
118 | data = data or partial | 142 | data = data or partial |
119 | print(who, " got ", string.len(data)) | ||
120 | -- renew our timestamp so scheduler sees we are active | 143 | -- renew our timestamp so scheduler sees we are active |
121 | context[who].last = socket.gettime() | 144 | context[who].last = socket.gettime() |
122 | -- forward what we got right away | 145 | -- forward what we got right away |
@@ -126,7 +149,6 @@ print(who, " got ", string.len(data)) | |||
126 | coroutine.yield("sending") | 149 | coroutine.yield("sending") |
127 | local ret, snd_err | 150 | local ret, snd_err |
128 | local start = 0 | 151 | local start = 0 |
129 | print(who, "sending data") | ||
130 | ret, snd_err, start = peer:send(data, start+1) | 152 | ret, snd_err, start = peer:send(data, start+1) |
131 | if ret then break | 153 | if ret then break |
132 | elseif snd_err ~= "timeout" then return error(snd_err) end | 154 | elseif snd_err ~= "timeout" then return error(snd_err) end |
@@ -143,51 +165,22 @@ end | |||
143 | 165 | ||
144 | -- checks if a connection completed successfully and if it did, starts | 166 | -- checks if a connection completed successfully and if it did, starts |
145 | -- forwarding all data | 167 | -- forwarding all data |
146 | function check(who) | 168 | function chkcon(who) |
147 | local ret, err = who:connected() | 169 | local ret, err = who:connected() |
148 | if ret then | 170 | if ret then |
149 | print(who, "connection completed") | ||
150 | receiving:insert(context[who].peer) | 171 | receiving:insert(context[who].peer) |
151 | context[who].last = socket.gettime() | 172 | context[who].last = socket.gettime() |
152 | print(who, "yielding until there is input data") | ||
153 | coroutine.yield("receiving") | 173 | coroutine.yield("receiving") |
154 | return forward(who) | 174 | return forward(who) |
155 | else return error(err) end | 175 | else return error(err) end |
156 | end | 176 | end |
157 | 177 | ||
158 | -- initializes the forward server | ||
159 | function init() | ||
160 | -- socket sets to test for events | ||
161 | -- create our server socket | ||
162 | server = assert(socket.bind(ihost, iport)) | ||
163 | server:settimeout(0.1) -- we don't want to be killed by bad luck | ||
164 | -- we initially | ||
165 | receiving:insert(server) | ||
166 | context[server] = { thread = coroutine.create(accept) } | ||
167 | end | ||
168 | |||
169 | -- loop waiting until something happens, restarting the thread to deal with | 178 | -- loop waiting until something happens, restarting the thread to deal with |
170 | -- what happened, and routing it to wait until something else happens | 179 | -- what happened, and routing it to wait until something else happens |
171 | function go() | 180 | function go() |
172 | while true do | 181 | while true do |
173 | print("will select for readability") | ||
174 | for i,v in ipairs(receiving) do | ||
175 | print(i, v) | ||
176 | end | ||
177 | print("will select for writability") | ||
178 | for i,v in ipairs(sending) do | ||
179 | print(i, v) | ||
180 | end | ||
181 | -- check which sockets are interesting and act on them | 182 | -- check which sockets are interesting and act on them |
182 | readable, writable = socket.select(receiving, sending, 3) | 183 | readable, writable = socket.select(receiving, sending, 3) |
183 | print("returned as readable") | ||
184 | for i,v in ipairs(readable) do | ||
185 | print(i, v) | ||
186 | end | ||
187 | print("returned as writable") | ||
188 | for i,v in ipairs(writable) do | ||
189 | print(i, v) | ||
190 | end | ||
191 | -- for all readable connections, resume its thread and route it | 184 | -- for all readable connections, resume its thread and route it |
192 | for _, who in ipairs(readable) do | 185 | for _, who in ipairs(readable) do |
193 | receiving:remove(who) | 186 | receiving:remove(who) |
@@ -207,7 +200,6 @@ end | |||
207 | local deathrow | 200 | local deathrow |
208 | for who, data in pairs(context) do | 201 | for who, data in pairs(context) do |
209 | if data.last then | 202 | if data.last then |
210 | print("hung for" , now - data.last, who) | ||
211 | if now - data.last > TIMEOUT then | 203 | if now - data.last > TIMEOUT then |
212 | -- only create table if someone is doomed | 204 | -- only create table if someone is doomed |
213 | deathrow = deathrow or {} | 205 | deathrow = deathrow or {} |
@@ -217,13 +209,10 @@ print("hung for" , now - data.last, who) | |||
217 | end | 209 | end |
218 | -- finally kick everyone in deathrow | 210 | -- finally kick everyone in deathrow |
219 | if deathrow then | 211 | if deathrow then |
220 | print("in death row") | ||
221 | for i,v in pairs(deathrow) do | ||
222 | print(i, v) | ||
223 | end | ||
224 | for who in pairs(deathrow) do kick(who) end | 212 | for who in pairs(deathrow) do kick(who) end |
225 | end | 213 | end |
226 | end | 214 | end |
227 | end | 215 | end |
228 | 216 | ||
229 | go(init()) | 217 | init() |
218 | go() | ||
diff --git a/src/http.lua b/src/http.lua index 1dff11a..38b93e2 100644 --- a/src/http.lua +++ b/src/http.lua | |||
@@ -32,13 +32,26 @@ USERAGENT = socket.VERSION | |||
32 | ----------------------------------------------------------------------------- | 32 | ----------------------------------------------------------------------------- |
33 | local metat = { __index = {} } | 33 | local metat = { __index = {} } |
34 | 34 | ||
35 | function open(host, port) | 35 | -- default connect function, respecting the timeout |
36 | local c = socket.try(socket.tcp()) | 36 | local function connect(host, port) |
37 | local c, e = socket.tcp() | ||
38 | if not c then return nil, e end | ||
39 | c:settimeout(TIMEOUT) | ||
40 | local r, e = c:connect(host, port or PORT) | ||
41 | if not r then | ||
42 | c:close() | ||
43 | return nil, e | ||
44 | end | ||
45 | return c | ||
46 | end | ||
47 | |||
48 | function open(host, port, user) | ||
49 | -- create socket with user connect function, or with default | ||
50 | local c = socket.try((user or connect)(host, port)) | ||
51 | -- create our http request object, pointing to the socket | ||
37 | local h = base.setmetatable({ c = c }, metat) | 52 | local h = base.setmetatable({ c = c }, metat) |
38 | -- make sure the connection gets closed on exception | 53 | -- make sure the object close gets called on exception |
39 | h.try = socket.newtry(function() h:close() end) | 54 | h.try = socket.newtry(function() h:close() end) |
40 | h.try(c:settimeout(TIMEOUT)) | ||
41 | h.try(c:connect(host, port or PORT)) | ||
42 | return h | 55 | return h |
43 | end | 56 | end |
44 | 57 | ||
@@ -215,13 +228,14 @@ function tredirect(reqt, headers) | |||
215 | sink = reqt.sink, | 228 | sink = reqt.sink, |
216 | headers = reqt.headers, | 229 | headers = reqt.headers, |
217 | proxy = reqt.proxy, | 230 | proxy = reqt.proxy, |
218 | nredirects = (reqt.nredirects or 0) + 1 | 231 | nredirects = (reqt.nredirects or 0) + 1, |
232 | connect = reqt.connect | ||
219 | } | 233 | } |
220 | end | 234 | end |
221 | 235 | ||
222 | function trequest(reqt) | 236 | function trequest(reqt) |
223 | reqt = adjustrequest(reqt) | 237 | reqt = adjustrequest(reqt) |
224 | local h = open(reqt.host, reqt.port) | 238 | local h = open(reqt.host, reqt.port, reqt.connect) |
225 | h:sendrequestline(reqt.method, reqt.uri) | 239 | h:sendrequestline(reqt.method, reqt.uri) |
226 | h:sendheaders(reqt.headers) | 240 | h:sendheaders(reqt.headers) |
227 | h:sendbody(reqt.headers, reqt.source, reqt.step) | 241 | h:sendbody(reqt.headers, reqt.source, reqt.step) |
diff --git a/test/httptest.lua b/test/httptest.lua index 2335fcb..8862ceb 100644 --- a/test/httptest.lua +++ b/test/httptest.lua | |||
@@ -23,7 +23,7 @@ http.TIMEOUT = 10 | |||
23 | local t = socket.gettime() | 23 | local t = socket.gettime() |
24 | 24 | ||
25 | host = host or "diego.student.princeton.edu" | 25 | host = host or "diego.student.princeton.edu" |
26 | proxy = proxy or "http://localhost:3128" | 26 | proxy = proxy or "http://dell-diego:3128" |
27 | prefix = prefix or "/luasocket-test" | 27 | prefix = prefix or "/luasocket-test" |
28 | cgiprefix = cgiprefix or "/luasocket-test-cgi" | 28 | cgiprefix = cgiprefix or "/luasocket-test-cgi" |
29 | index_file = "test/index.html" | 29 | index_file = "test/index.html" |