aboutsummaryrefslogtreecommitdiff
path: root/samples/forward.lua
diff options
context:
space:
mode:
Diffstat (limited to 'samples/forward.lua')
-rw-r--r--samples/forward.lua138
1 files changed, 71 insertions, 67 deletions
diff --git a/samples/forward.lua b/samples/forward.lua
index c3f0605..46f51c8 100644
--- a/samples/forward.lua
+++ b/samples/forward.lua
@@ -35,6 +35,13 @@ local sending = newset()
35-- context for connections and servers 35-- context for connections and servers
36local context = {} 36local context = {}
37 37
38function wait(who, what)
39 if what == "input" then receiving:insert(who)
40 else sending:insert(who) end
41 context[who].last = socket.gettime()
42 coroutine.yield()
43end
44
38-- initializes the forward server 45-- initializes the forward server
39function init() 46function init()
40 if table.getn(arg) < 1 then 47 if table.getn(arg) < 1 then
@@ -63,145 +70,142 @@ function init()
63end 70end
64 71
65-- starts a connection in a non-blocking way 72-- starts a connection in a non-blocking way
66function nbkcon(host, port) 73function connect(who, host, port)
67 local peer, err = socket.tcp() 74 who:settimeout(0.1)
68 if not peer then return nil, err end 75print("trying to connect peer", who, host, port)
69 peer:settimeout(0) 76 local ret, err = who:connect(host, port)
70 local ret, err = peer:connect(host, port) 77 if not ret and err == "timeout" then
71 if ret then return peer end 78print("got timeout, will wait", who)
72 if err ~= "timeout" then 79 wait(who, "output")
73 peer:close() 80 ret, err = who:connected()
74 return nil, err 81print("connection results arrived", who, ret, err)
82 end
83 if not ret then
84print("connection failed", who)
85 kick(who)
86 kick(context[who].peer)
87 else
88 return forward(who)
75 end 89 end
76 return peer
77end 90end
78 91
79-- gets rid of a client 92-- gets rid of a client and its peer
80function kick(who) 93function kick(who)
81 if context[who] then 94 if who and context[who] then
82 sending:remove(who) 95 sending:remove(who)
83 receiving:remove(who) 96 receiving:remove(who)
97 local peer = context[who].peer
84 context[who] = nil 98 context[who] = nil
85 who:close() 99 who:close()
86 end 100 end
87end 101end
88 102
89-- decides what to do with a thread based on coroutine return
90function route(who, status, what)
91 if status and what then
92 if what == "receiving" then receiving:insert(who) end
93 if what == "sending" then sending:insert(who) end
94 else kick(who) end
95end
96
97-- loops accepting connections and creating new threads to deal with them 103-- loops accepting connections and creating new threads to deal with them
98function accept(server) 104function accept(server)
99 while true do 105 while true do
100 -- accept a new connection and start a new coroutine to deal with it 106 -- accept a new connection and start a new coroutine to deal with it
101 local client = server:accept() 107 local client = server:accept()
108print("accepted ", client)
102 if client then 109 if client then
103 -- start a new connection, non-blockingly, to the forwarding address 110 -- create contexts for client and peer.
104 local ohost = context[server].ohost 111 local peer, err = socket.tcp()
105 local oport = context[server].oport
106 local peer = nbkcon(ohost, oport)
107 if peer then 112 if peer then
108 context[client] = { 113 context[client] = {
109 last = socket.gettime(), 114 last = socket.gettime(),
115 -- client goes straight to forwarding loop
110 thread = coroutine.create(forward), 116 thread = coroutine.create(forward),
111 peer = peer, 117 peer = peer,
112 } 118 }
113 -- make sure peer will be tested for writing in the next select
114 -- round, which means the connection attempt has finished
115 sending:insert(peer)
116 context[peer] = { 119 context[peer] = {
120 last = socket.gettime(),
117 peer = client, 121 peer = client,
118 thread = coroutine.create(chkcon), 122 -- peer first tries to connect to forwarding address
123 thread = coroutine.create(connect),
119 last = socket.gettime() 124 last = socket.gettime()
120 } 125 }
121 -- put both in non-blocking mode 126 -- resume peer and client so they can do their thing
122 client:settimeout(0) 127 local ohost = context[server].ohost
123 peer:settimeout(0) 128 local oport = context[server].oport
129 coroutine.resume(context[peer].thread, peer, ohost, oport)
130 coroutine.resume(context[client].thread, client)
124 else 131 else
125 -- otherwise just dump the client 132 print(err)
126 client:close() 133 client:close()
127 end 134 end
128 end 135 end
129 -- tell scheduler we are done for now 136 -- tell scheduler we are done for now
130 coroutine.yield("receiving") 137 wait(server, "input")
131 end 138 end
132end 139end
133 140
134-- forwards all data arriving to the appropriate peer 141-- forwards all data arriving to the appropriate peer
135function forward(who) 142function forward(who)
143print("starting to foward", who)
144 who:settimeout(0)
136 while true do 145 while true do
146 -- wait until we have something to read
147 wait(who, "input")
137 -- try to read as much as possible 148 -- try to read as much as possible
138 local data, rec_err, partial = who:receive("*a") 149 local data, rec_err, partial = who:receive("*a")
139 -- if we had an error other than timeout, abort 150 -- if we had an error other than timeout, abort
140 if rec_err and rec_err ~= "timeout" then return error(rec_err) end 151 if rec_err and rec_err ~= "timeout" then return kick(who) end
141 -- if we got a timeout, we probably have partial results to send 152 -- if we got a timeout, we probably have partial results to send
142 data = data or partial 153 data = data or partial
143 -- renew our timestamp so scheduler sees we are active
144 context[who].last = socket.gettime()
145 -- forward what we got right away 154 -- forward what we got right away
146 local peer = context[who].peer 155 local peer = context[who].peer
147 while true do 156 while true do
148 -- tell scheduler we need to wait until we can send something 157 -- tell scheduler we need to wait until we can send something
149 coroutine.yield("sending") 158 wait(who, "output")
150 local ret, snd_err 159 local ret, snd_err
151 local start = 0 160 local start = 0
152 ret, snd_err, start = peer:send(data, start+1) 161 ret, snd_err, start = peer:send(data, start+1)
153 if ret then break 162 if ret then break
154 elseif snd_err ~= "timeout" then return error(snd_err) end 163 elseif snd_err ~= "timeout" then return kick(who) end
155 -- renew our timestamp so scheduler sees we are active
156 context[who].last = socket.gettime()
157 end 164 end
158 -- if we are done receiving, we are done with this side of the 165 -- if we are done receiving, we are done
159 -- connection 166 if not rec_err then return kick(who) end
160 if not rec_err then return nil end
161 -- otherwise tell schedule we have to wait for more data to arrive
162 coroutine.yield("receiving")
163 end 167 end
164end 168end
165 169
166-- checks if a connection completed successfully and if it did, starts
167-- forwarding all data
168function chkcon(who)
169 local ret, err = who:connected()
170 if ret then
171 receiving:insert(context[who].peer)
172 context[who].last = socket.gettime()
173 coroutine.yield("receiving")
174 return forward(who)
175 else return error(err) end
176end
177
178-- loop waiting until something happens, restarting the thread to deal with 170-- loop waiting until something happens, restarting the thread to deal with
179-- what happened, and routing it to wait until something else happens 171-- what happened, and routing it to wait until something else happens
180function go() 172function go()
181 while true do 173 while true do
174print("will select for reading")
175for i,v in ipairs(receiving) do
176 print(i, v)
177end
178print("will select for sending")
179for i,v in ipairs(sending) do
180 print(i, v)
181end
182 -- check which sockets are interesting and act on them 182 -- check which sockets are interesting and act on them
183 readable, writable = socket.select(receiving, sending, 3) 183 readable, writable = socket.select(receiving, sending, 3)
184 -- for all readable connections, resume its thread and route it 184print("was readable")
185for i,v in ipairs(readable) do
186 print(i, v)
187end
188print("was writable")
189for i,v in ipairs(writable) do
190 print(i, v)
191end
192 -- for all readable connections, resume its thread
185 for _, who in ipairs(readable) do 193 for _, who in ipairs(readable) do
186 receiving:remove(who) 194 receiving:remove(who)
187 if context[who] then 195 coroutine.resume(context[who].thread, who)
188 route(who, coroutine.resume(context[who].thread, who))
189 end
190 end 196 end
191 -- for all writable connections, do the same 197 -- for all writable connections, do the same
192 for _, who in ipairs(writable) do 198 for _, who in ipairs(writable) do
193 sending:remove(who) 199 sending:remove(who)
194 if context[who] then 200 coroutine.resume(context[who].thread, who)
195 route(who, coroutine.resume(context[who].thread, who))
196 end
197 end 201 end
198 -- put all inactive threads in death row 202 -- put all inactive threads in death row
199 local now = socket.gettime() 203 local now = socket.gettime()
200 local deathrow 204 local deathrow
201 for who, data in pairs(context) do 205 for who, data in pairs(context) do
202 if data.last then 206 if data.peer then
203 if now - data.last > TIMEOUT then 207 if now - data.last > TIMEOUT then
204 -- only create table if someone is doomed 208 -- only create table if at least one is doomed
205 deathrow = deathrow or {} 209 deathrow = deathrow or {}
206 deathrow[who] = true 210 deathrow[who] = true
207 end 211 end