diff options
-rw-r--r-- | samples/forward.lua | 229 |
1 files changed, 229 insertions, 0 deletions
diff --git a/samples/forward.lua b/samples/forward.lua new file mode 100644 index 0000000..de651b4 --- /dev/null +++ b/samples/forward.lua | |||
@@ -0,0 +1,229 @@ | |||
1 | -- load our favourite library | ||
2 | local socket = require"socket" | ||
3 | -- timeout before an inactive thread is kicked | ||
4 | local TIMEOUT = 10 | ||
5 | -- local address to bind to | ||
6 | local ihost, iport = arg[1] or "localhost", arg[2] or 8080 | ||
7 | -- address to forward all data to | ||
8 | local ohost, oport = arg[3] or "localhost", arg[4] or 3128 | ||
9 | |||
10 | -- creates a new set data structure | ||
11 | function newset() | ||
12 | local reverse = {} | ||
13 | local set = {} | ||
14 | return setmetatable(set, {__index = { | ||
15 | insert = function(set, value) | ||
16 | if not reverse[value] then | ||
17 | table.insert(set, value) | ||
18 | reverse[value] = table.getn(set) | ||
19 | end | ||
20 | end, | ||
21 | remove = function(set, value) | ||
22 | local index = reverse[value] | ||
23 | if index then | ||
24 | reverse[value] = nil | ||
25 | local top = table.remove(set) | ||
26 | if top ~= value then | ||
27 | reverse[top] = index | ||
28 | set[index] = top | ||
29 | end | ||
30 | end | ||
31 | end | ||
32 | }}) | ||
33 | end | ||
34 | |||
35 | local receiving = newset() | ||
36 | local sending = newset() | ||
37 | local context = {} | ||
38 | |||
39 | -- starts a non-blocking connect | ||
40 | function nconnect(host, port) | ||
41 | local peer, err = socket.tcp() | ||
42 | if not peer then return nil, err end | ||
43 | peer:settimeout(0) | ||
44 | local ret, err = peer:connect(host, port) | ||
45 | if ret then return peer end | ||
46 | if err ~= "timeout" then | ||
47 | peer:close() | ||
48 | return nil, err | ||
49 | end | ||
50 | return peer | ||
51 | end | ||
52 | |||
53 | -- gets rid of a client | ||
54 | function kick(who) | ||
55 | if who == server then error("FUDEU") end | ||
56 | if context[who] then | ||
57 | sending:remove(who) | ||
58 | receiving:remove(who) | ||
59 | context[who] = nil | ||
60 | who:close() | ||
61 | end | ||
62 | end | ||
63 | |||
64 | -- decides what to do with a thread based on coroutine return | ||
65 | function route(who, status, what) | ||
66 | print(who, status, what) | ||
67 | if status and what then | ||
68 | if what == "receiving" then receiving:insert(who) end | ||
69 | if what == "sending" then sending:insert(who) end | ||
70 | else kick(who) end | ||
71 | end | ||
72 | |||
73 | -- loops accepting connections and creating new threads to deal with them | ||
74 | function accept(server) | ||
75 | while true do | ||
76 | print(server, "accepting a new client") | ||
77 | -- accept a new connection and start a new coroutine to deal with it | ||
78 | local client = server:accept() | ||
79 | if client then | ||
80 | -- start a new connection, non-blockingly, to the forwarding address | ||
81 | local peer = nconnect(ohost, oport) | ||
82 | if peer then | ||
83 | context[client] = { | ||
84 | last = socket.gettime(), | ||
85 | thread = coroutine.create(forward), | ||
86 | peer = peer, | ||
87 | } | ||
88 | -- make sure peer will be tested for writing in the next select | ||
89 | -- round, which means the connection attempt has finished | ||
90 | sending:insert(peer) | ||
91 | context[peer] = { | ||
92 | peer = client, | ||
93 | thread = coroutine.create(check), | ||
94 | last = socket.gettime() | ||
95 | } | ||
96 | -- put both in non-blocking mode | ||
97 | client:settimeout(0) | ||
98 | peer:settimeout(0) | ||
99 | else | ||
100 | -- otherwise just dump the client | ||
101 | client:close() | ||
102 | end | ||
103 | end | ||
104 | -- tell scheduler we are done for now | ||
105 | coroutine.yield("receiving") | ||
106 | end | ||
107 | end | ||
108 | |||
109 | -- forwards all data arriving to the appropriate peer | ||
110 | function forward(who) | ||
111 | while true do | ||
112 | print(who, "getting data") | ||
113 | -- try to read as much as possible | ||
114 | local data, rec_err, partial = who:receive("*a") | ||
115 | -- if we had an error other than timeout, abort | ||
116 | if rec_err and rec_err ~= "timeout" then return error(rec_err) end | ||
117 | -- if we got a timeout, we probably have partial results to send | ||
118 | data = data or partial | ||
119 | print(who, " got ", string.len(data)) | ||
120 | -- renew our timestamp so scheduler sees we are active | ||
121 | context[who].last = socket.gettime() | ||
122 | -- forward what we got right away | ||
123 | local peer = context[who].peer | ||
124 | while true do | ||
125 | -- tell scheduler we need to wait until we can send something | ||
126 | coroutine.yield("sending") | ||
127 | local ret, snd_err | ||
128 | local start = 0 | ||
129 | print(who, "sending data") | ||
130 | ret, snd_err, start = peer:send(data, start+1) | ||
131 | if ret then break | ||
132 | elseif snd_err ~= "timeout" then return error(snd_err) end | ||
133 | -- renew our timestamp so scheduler sees we are active | ||
134 | context[who].last = socket.gettime() | ||
135 | end | ||
136 | -- if we are done receiving, we are done with this side of the | ||
137 | -- connection | ||
138 | if not rec_err then return nil end | ||
139 | -- otherwise tell schedule we have to wait for more data to arrive | ||
140 | coroutine.yield("receiving") | ||
141 | end | ||
142 | end | ||
143 | |||
144 | -- checks if a connection completed successfully and if it did, starts | ||
145 | -- forwarding all data | ||
146 | function check(who) | ||
147 | local ret, err = who:connected() | ||
148 | if ret then | ||
149 | print(who, "connection completed") | ||
150 | receiving:insert(context[who].peer) | ||
151 | context[who].last = socket.gettime() | ||
152 | print(who, "yielding until there is input data") | ||
153 | coroutine.yield("receiving") | ||
154 | return forward(who) | ||
155 | else return error(err) end | ||
156 | end | ||
157 | |||
158 | -- initializes the forward server | ||
159 | function init() | ||
160 | -- socket sets to test for events | ||
161 | -- create our server socket | ||
162 | server = assert(socket.bind(ihost, iport)) | ||
163 | server:settimeout(0.1) -- we don't want to be killed by bad luck | ||
164 | -- we initially | ||
165 | receiving:insert(server) | ||
166 | context[server] = { thread = coroutine.create(accept) } | ||
167 | end | ||
168 | |||
169 | -- loop waiting until something happens, restarting the thread to deal with | ||
170 | -- what happened, and routing it to wait until something else happens | ||
171 | function go() | ||
172 | while true do | ||
173 | print("will select for readability") | ||
174 | for i,v in ipairs(receiving) do | ||
175 | print(i, v) | ||
176 | end | ||
177 | print("will select for writability") | ||
178 | for i,v in ipairs(sending) do | ||
179 | print(i, v) | ||
180 | end | ||
181 | -- check which sockets are interesting and act on them | ||
182 | readable, writable = socket.select(receiving, sending, 3) | ||
183 | print("returned as readable") | ||
184 | for i,v in ipairs(readable) do | ||
185 | print(i, v) | ||
186 | end | ||
187 | print("returned as writable") | ||
188 | for i,v in ipairs(writable) do | ||
189 | print(i, v) | ||
190 | end | ||
191 | -- for all readable connections, resume its thread and route it | ||
192 | for _, who in ipairs(readable) do | ||
193 | receiving:remove(who) | ||
194 | if context[who] then | ||
195 | route(who, coroutine.resume(context[who].thread, who)) | ||
196 | end | ||
197 | end | ||
198 | -- for all writable connections, do the same | ||
199 | for _, who in ipairs(writable) do | ||
200 | sending:remove(who) | ||
201 | if context[who] then | ||
202 | route(who, coroutine.resume(context[who].thread, who)) | ||
203 | end | ||
204 | end | ||
205 | -- put all inactive threads in death row | ||
206 | local now = socket.gettime() | ||
207 | local deathrow | ||
208 | for who, data in pairs(context) do | ||
209 | if data.last then | ||
210 | print("hung for" , now - data.last, who) | ||
211 | if now - data.last > TIMEOUT then | ||
212 | -- only create table if someone is doomed | ||
213 | deathrow = deathrow or {} | ||
214 | deathrow[who] = true | ||
215 | end | ||
216 | end | ||
217 | end | ||
218 | -- finally kick everyone in deathrow | ||
219 | if deathrow then | ||
220 | print("in death row") | ||
221 | for i,v in pairs(deathrow) do | ||
222 | print(i, v) | ||
223 | end | ||
224 | for who in pairs(deathrow) do kick(who) end | ||
225 | end | ||
226 | end | ||
227 | end | ||
228 | |||
229 | go(init()) | ||