aboutsummaryrefslogtreecommitdiff
path: root/samples/forward.lua
diff options
context:
space:
mode:
authorDiego Nehab <diego@tecgraf.puc-rio.br>2005-11-20 08:56:19 +0000
committerDiego Nehab <diego@tecgraf.puc-rio.br>2005-11-20 08:56:19 +0000
commit5e09779c7f6b1710150d5a0f12d86ded7ede75c6 (patch)
treeffd2e2e7d918cc30015c89bd14122aa8cadb1546 /samples/forward.lua
parentf20f4889bfe5a02cd9b77868b90cc8042352176a (diff)
downloadluasocket-5e09779c7f6b1710150d5a0f12d86ded7ede75c6.tar.gz
luasocket-5e09779c7f6b1710150d5a0f12d86ded7ede75c6.tar.bz2
luasocket-5e09779c7f6b1710150d5a0f12d86ded7ede75c6.zip
In pre release mode!
Diffstat (limited to 'samples/forward.lua')
-rw-r--r--samples/forward.lua207
1 files changed, 0 insertions, 207 deletions
diff --git a/samples/forward.lua b/samples/forward.lua
deleted file mode 100644
index 548a753..0000000
--- a/samples/forward.lua
+++ /dev/null
@@ -1,207 +0,0 @@
1-- load our favourite library
2local socket = require"socket"
3
4-- creates a new set data structure
5function newset()
6 local reverse = {}
7 local set = {}
8 return setmetatable(set, {__index = {
9 insert = function(set, value)
10 if not reverse[value] then
11 table.insert(set, value)
12 reverse[value] = table.getn(set)
13 end
14 end,
15 remove = function(set, value)
16 local index = reverse[value]
17 if index then
18 reverse[value] = nil
19 local top = table.remove(set)
20 if top ~= value then
21 reverse[top] = index
22 set[index] = top
23 end
24 end
25 end
26 }})
27end
28
29-- timeout before an inactive thread is kicked
30local TIMEOUT = 10
31-- set of connections waiting to receive data
32local receiving = newset(1)
33-- set of sockets waiting to send data
34local sending = newset()
35-- context for connections and servers
36local context = {}
37
38function wait(who, what)
39 if what == "input" then receiving:insert(who)
40 else sending:insert(who) end
41 context[who].last = socket.gettime()
42 coroutine.yield()
43end
44
45-- initializes the forward server
46function init()
47 if table.getn(arg) < 1 then
48 print("Usage")
49 print(" lua forward.lua <iport:ohost:oport> ...")
50 os.exit(1)
51 end
52 -- for each tunnel, start a new server socket
53 for i, v in ipairs(arg) do
54 -- capture forwarding parameters
55 local iport, ohost, oport =
56 socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)"))
57 assert(iport, "invalid arguments")
58 -- create our server socket
59 local server = assert(socket.bind("*", iport))
60 server:settimeout(0) -- we don't want to be killed by bad luck
61 -- make sure server is tested for readability
62 receiving:insert(server)
63 -- add server context
64 context[server] = {
65 thread = coroutine.create(accept),
66 ohost = ohost,
67 oport = oport
68 }
69 end
70end
71
72-- starts a connection in a non-blocking way
73function connect(who, host, port)
74 who:settimeout(0)
75 local ret, err = who:connect(host, port)
76 if not ret and err == "timeout" then
77 wait(who, "output")
78 ret, err = who:connect(host, port)
79 if not ret and err ~= "already connected" then
80 kick(context[who].peer)
81 kick(who)
82 return
83 end
84 end
85 return forward(who)
86end
87
88-- gets rid of a client
89function kick(who)
90 if who then
91 sending:remove(who)
92 receiving:remove(who)
93 who:close()
94 context[who] = nil
95 end
96end
97
98-- loops accepting connections and creating new threads to deal with them
99function accept(server)
100 while true do
101 -- accept a new connection and start a new coroutine to deal with it
102 local client = server:accept()
103 if client then
104 -- create contexts for client and peer.
105 local peer, err = socket.tcp()
106 if peer then
107 context[client] = {
108 last = socket.gettime(),
109 -- client goes straight to forwarding loop
110 thread = coroutine.create(forward),
111 peer = peer,
112 }
113 context[peer] = {
114 last = socket.gettime(),
115 peer = client,
116 -- peer first tries to connect to forwarding address
117 thread = coroutine.create(connect),
118 last = socket.gettime()
119 }
120 -- resume peer and client so they can do their thing
121 local ohost = context[server].ohost
122 local oport = context[server].oport
123 coroutine.resume(context[peer].thread, peer, ohost, oport)
124 coroutine.resume(context[client].thread, client)
125 else
126 print(err)
127 client:close()
128 end
129 end
130 -- tell scheduler we are done for now
131 wait(server, "input")
132 end
133end
134
135-- forwards all data arriving to the appropriate peer
136function forward(who)
137 who:settimeout(0)
138 while true do
139 -- wait until we have something to read
140 wait(who, "input")
141 -- try to read as much as possible
142 local data, rec_err, partial = who:receive("*a")
143 -- if we had an error other than timeout, abort
144 if rec_err and rec_err ~= "timeout" then return kick(who) end
145 -- if we got a timeout, we probably have partial results to send
146 data = data or partial
147 -- forward what we got right away
148 local peer = context[who].peer
149 while true do
150 -- tell scheduler we need to wait until we can send something
151 wait(who, "output")
152 local ret, snd_err
153 local start = 0
154 ret, snd_err, start = peer:send(data, start+1)
155 if ret then break
156 elseif snd_err ~= "timeout" then return kick(who) end
157 end
158 -- if we are done receiving, we are done
159 if not rec_err then
160 kick(who)
161 kick(peer)
162 break
163 end
164 end
165end
166
167-- loop waiting until something happens, restarting the thread to deal with
168-- what happened, and routing it to wait until something else happens
169function go()
170 while true do
171 -- check which sockets are interesting and act on them
172 readable, writable = socket.select(receiving, sending)
173 -- for all readable connections, resume its thread
174 for _, who in ipairs(readable) do
175 if context[who] then
176 receiving:remove(who)
177 coroutine.resume(context[who].thread, who)
178 end
179 end
180 -- for all writable connections, do the same
181 for _, who in ipairs(writable) do
182 if context[who] then
183 sending:remove(who)
184 coroutine.resume(context[who].thread, who)
185 end
186 end
187 -- put all inactive threads in death row
188 local now = socket.gettime()
189 local deathrow
190 for who, data in pairs(context) do
191 if data.peer then
192 if now - data.last > TIMEOUT then
193 -- only create table if at least one is doomed
194 deathrow = deathrow or {}
195 deathrow[who] = true
196 end
197 end
198 end
199 -- finally kick everyone in deathrow
200 if deathrow then
201 for who in pairs(deathrow) do kick(who) end
202 end
203 end
204end
205
206init()
207go()