aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--samples/forward.lua229
1 files changed, 229 insertions, 0 deletions
diff --git a/samples/forward.lua b/samples/forward.lua
new file mode 100644
index 0000000..de651b4
--- /dev/null
+++ b/samples/forward.lua
@@ -0,0 +1,229 @@
1-- load our favourite library
2local socket = require"socket"
3-- timeout before an inactive thread is kicked
4local TIMEOUT = 10
5-- local address to bind to
6local ihost, iport = arg[1] or "localhost", arg[2] or 8080
7-- address to forward all data to
8local ohost, oport = arg[3] or "localhost", arg[4] or 3128
9
10-- creates a new set data structure
11function newset()
12 local reverse = {}
13 local set = {}
14 return setmetatable(set, {__index = {
15 insert = function(set, value)
16 if not reverse[value] then
17 table.insert(set, value)
18 reverse[value] = table.getn(set)
19 end
20 end,
21 remove = function(set, value)
22 local index = reverse[value]
23 if index then
24 reverse[value] = nil
25 local top = table.remove(set)
26 if top ~= value then
27 reverse[top] = index
28 set[index] = top
29 end
30 end
31 end
32 }})
33end
34
35local receiving = newset()
36local sending = newset()
37local context = {}
38
39-- starts a non-blocking connect
40function nconnect(host, port)
41 local peer, err = socket.tcp()
42 if not peer then return nil, err end
43 peer:settimeout(0)
44 local ret, err = peer:connect(host, port)
45 if ret then return peer end
46 if err ~= "timeout" then
47 peer:close()
48 return nil, err
49 end
50 return peer
51end
52
53-- gets rid of a client
54function kick(who)
55if who == server then error("FUDEU") end
56 if context[who] then
57 sending:remove(who)
58 receiving:remove(who)
59 context[who] = nil
60 who:close()
61 end
62end
63
64-- decides what to do with a thread based on coroutine return
65function route(who, status, what)
66print(who, status, what)
67 if status and what then
68 if what == "receiving" then receiving:insert(who) end
69 if what == "sending" then sending:insert(who) end
70 else kick(who) end
71end
72
73-- loops accepting connections and creating new threads to deal with them
74function accept(server)
75 while true do
76print(server, "accepting a new client")
77 -- accept a new connection and start a new coroutine to deal with it
78 local client = server:accept()
79 if client then
80 -- start a new connection, non-blockingly, to the forwarding address
81 local peer = nconnect(ohost, oport)
82 if peer then
83 context[client] = {
84 last = socket.gettime(),
85 thread = coroutine.create(forward),
86 peer = peer,
87 }
88 -- make sure peer will be tested for writing in the next select
89 -- round, which means the connection attempt has finished
90 sending:insert(peer)
91 context[peer] = {
92 peer = client,
93 thread = coroutine.create(check),
94 last = socket.gettime()
95 }
96 -- put both in non-blocking mode
97 client:settimeout(0)
98 peer:settimeout(0)
99 else
100 -- otherwise just dump the client
101 client:close()
102 end
103 end
104 -- tell scheduler we are done for now
105 coroutine.yield("receiving")
106 end
107end
108
109-- forwards all data arriving to the appropriate peer
110function forward(who)
111 while true do
112print(who, "getting data")
113 -- try to read as much as possible
114 local data, rec_err, partial = who:receive("*a")
115 -- if we had an error other than timeout, abort
116 if rec_err and rec_err ~= "timeout" then return error(rec_err) end
117 -- if we got a timeout, we probably have partial results to send
118 data = data or partial
119print(who, " got ", string.len(data))
120 -- renew our timestamp so scheduler sees we are active
121 context[who].last = socket.gettime()
122 -- forward what we got right away
123 local peer = context[who].peer
124 while true do
125 -- tell scheduler we need to wait until we can send something
126 coroutine.yield("sending")
127 local ret, snd_err
128 local start = 0
129print(who, "sending data")
130 ret, snd_err, start = peer:send(data, start+1)
131 if ret then break
132 elseif snd_err ~= "timeout" then return error(snd_err) end
133 -- renew our timestamp so scheduler sees we are active
134 context[who].last = socket.gettime()
135 end
136 -- if we are done receiving, we are done with this side of the
137 -- connection
138 if not rec_err then return nil end
139 -- otherwise tell schedule we have to wait for more data to arrive
140 coroutine.yield("receiving")
141 end
142end
143
144-- checks if a connection completed successfully and if it did, starts
145-- forwarding all data
146function check(who)
147 local ret, err = who:connected()
148 if ret then
149print(who, "connection completed")
150 receiving:insert(context[who].peer)
151 context[who].last = socket.gettime()
152print(who, "yielding until there is input data")
153 coroutine.yield("receiving")
154 return forward(who)
155 else return error(err) end
156end
157
158-- initializes the forward server
159function init()
160 -- socket sets to test for events
161 -- create our server socket
162 server = assert(socket.bind(ihost, iport))
163 server:settimeout(0.1) -- we don't want to be killed by bad luck
164 -- we initially
165 receiving:insert(server)
166 context[server] = { thread = coroutine.create(accept) }
167end
168
169-- loop waiting until something happens, restarting the thread to deal with
170-- what happened, and routing it to wait until something else happens
171function go()
172 while true do
173print("will select for readability")
174for i,v in ipairs(receiving) do
175 print(i, v)
176end
177print("will select for writability")
178for i,v in ipairs(sending) do
179 print(i, v)
180end
181 -- check which sockets are interesting and act on them
182 readable, writable = socket.select(receiving, sending, 3)
183print("returned as readable")
184for i,v in ipairs(readable) do
185 print(i, v)
186end
187print("returned as writable")
188for i,v in ipairs(writable) do
189 print(i, v)
190end
191 -- for all readable connections, resume its thread and route it
192 for _, who in ipairs(readable) do
193 receiving:remove(who)
194 if context[who] then
195 route(who, coroutine.resume(context[who].thread, who))
196 end
197 end
198 -- for all writable connections, do the same
199 for _, who in ipairs(writable) do
200 sending:remove(who)
201 if context[who] then
202 route(who, coroutine.resume(context[who].thread, who))
203 end
204 end
205 -- put all inactive threads in death row
206 local now = socket.gettime()
207 local deathrow
208 for who, data in pairs(context) do
209 if data.last then
210print("hung for" , now - data.last, who)
211 if now - data.last > TIMEOUT then
212 -- only create table if someone is doomed
213 deathrow = deathrow or {}
214 deathrow[who] = true
215 end
216 end
217 end
218 -- finally kick everyone in deathrow
219 if deathrow then
220print("in death row")
221for i,v in pairs(deathrow) do
222 print(i, v)
223end
224 for who in pairs(deathrow) do kick(who) end
225 end
226 end
227end
228
229go(init())