diff options
Diffstat (limited to 'samples/forward.lua')
-rw-r--r-- | samples/forward.lua | 138 |
1 files changed, 71 insertions, 67 deletions
diff --git a/samples/forward.lua b/samples/forward.lua index c3f0605..46f51c8 100644 --- a/samples/forward.lua +++ b/samples/forward.lua | |||
@@ -35,6 +35,13 @@ local sending = newset() | |||
35 | -- context for connections and servers | 35 | -- context for connections and servers |
36 | local context = {} | 36 | local context = {} |
37 | 37 | ||
38 | function wait(who, what) | ||
39 | if what == "input" then receiving:insert(who) | ||
40 | else sending:insert(who) end | ||
41 | context[who].last = socket.gettime() | ||
42 | coroutine.yield() | ||
43 | end | ||
44 | |||
38 | -- initializes the forward server | 45 | -- initializes the forward server |
39 | function init() | 46 | function init() |
40 | if table.getn(arg) < 1 then | 47 | if table.getn(arg) < 1 then |
@@ -63,145 +70,142 @@ function init() | |||
63 | end | 70 | end |
64 | 71 | ||
65 | -- starts a connection in a non-blocking way | 72 | -- starts a connection in a non-blocking way |
66 | function nbkcon(host, port) | 73 | function connect(who, host, port) |
67 | local peer, err = socket.tcp() | 74 | who:settimeout(0.1) |
68 | if not peer then return nil, err end | 75 | print("trying to connect peer", who, host, port) |
69 | peer:settimeout(0) | 76 | local ret, err = who:connect(host, port) |
70 | local ret, err = peer:connect(host, port) | 77 | if not ret and err == "timeout" then |
71 | if ret then return peer end | 78 | print("got timeout, will wait", who) |
72 | if err ~= "timeout" then | 79 | wait(who, "output") |
73 | peer:close() | 80 | ret, err = who:connected() |
74 | return nil, err | 81 | print("connection results arrived", who, ret, err) |
82 | end | ||
83 | if not ret then | ||
84 | print("connection failed", who) | ||
85 | kick(who) | ||
86 | kick(context[who].peer) | ||
87 | else | ||
88 | return forward(who) | ||
75 | end | 89 | end |
76 | return peer | ||
77 | end | 90 | end |
78 | 91 | ||
79 | -- gets rid of a client | 92 | -- gets rid of a client and its peer |
80 | function kick(who) | 93 | function kick(who) |
81 | if context[who] then | 94 | if who and context[who] then |
82 | sending:remove(who) | 95 | sending:remove(who) |
83 | receiving:remove(who) | 96 | receiving:remove(who) |
97 | local peer = context[who].peer | ||
84 | context[who] = nil | 98 | context[who] = nil |
85 | who:close() | 99 | who:close() |
86 | end | 100 | end |
87 | end | 101 | end |
88 | 102 | ||
89 | -- decides what to do with a thread based on coroutine return | ||
90 | function route(who, status, what) | ||
91 | if status and what then | ||
92 | if what == "receiving" then receiving:insert(who) end | ||
93 | if what == "sending" then sending:insert(who) end | ||
94 | else kick(who) end | ||
95 | end | ||
96 | |||
97 | -- loops accepting connections and creating new threads to deal with them | 103 | -- loops accepting connections and creating new threads to deal with them |
98 | function accept(server) | 104 | function accept(server) |
99 | while true do | 105 | while true do |
100 | -- accept a new connection and start a new coroutine to deal with it | 106 | -- accept a new connection and start a new coroutine to deal with it |
101 | local client = server:accept() | 107 | local client = server:accept() |
108 | print("accepted ", client) | ||
102 | if client then | 109 | if client then |
103 | -- start a new connection, non-blockingly, to the forwarding address | 110 | -- create contexts for client and peer. |
104 | local ohost = context[server].ohost | 111 | local peer, err = socket.tcp() |
105 | local oport = context[server].oport | ||
106 | local peer = nbkcon(ohost, oport) | ||
107 | if peer then | 112 | if peer then |
108 | context[client] = { | 113 | context[client] = { |
109 | last = socket.gettime(), | 114 | last = socket.gettime(), |
115 | -- client goes straight to forwarding loop | ||
110 | thread = coroutine.create(forward), | 116 | thread = coroutine.create(forward), |
111 | peer = peer, | 117 | peer = peer, |
112 | } | 118 | } |
113 | -- make sure peer will be tested for writing in the next select | ||
114 | -- round, which means the connection attempt has finished | ||
115 | sending:insert(peer) | ||
116 | context[peer] = { | 119 | context[peer] = { |
120 | last = socket.gettime(), | ||
117 | peer = client, | 121 | peer = client, |
118 | thread = coroutine.create(chkcon), | 122 | -- peer first tries to connect to forwarding address |
123 | thread = coroutine.create(connect), | ||
119 | last = socket.gettime() | 124 | last = socket.gettime() |
120 | } | 125 | } |
121 | -- put both in non-blocking mode | 126 | -- resume peer and client so they can do their thing |
122 | client:settimeout(0) | 127 | local ohost = context[server].ohost |
123 | peer:settimeout(0) | 128 | local oport = context[server].oport |
129 | coroutine.resume(context[peer].thread, peer, ohost, oport) | ||
130 | coroutine.resume(context[client].thread, client) | ||
124 | else | 131 | else |
125 | -- otherwise just dump the client | 132 | print(err) |
126 | client:close() | 133 | client:close() |
127 | end | 134 | end |
128 | end | 135 | end |
129 | -- tell scheduler we are done for now | 136 | -- tell scheduler we are done for now |
130 | coroutine.yield("receiving") | 137 | wait(server, "input") |
131 | end | 138 | end |
132 | end | 139 | end |
133 | 140 | ||
134 | -- forwards all data arriving to the appropriate peer | 141 | -- forwards all data arriving to the appropriate peer |
135 | function forward(who) | 142 | function forward(who) |
143 | print("starting to foward", who) | ||
144 | who:settimeout(0) | ||
136 | while true do | 145 | while true do |
146 | -- wait until we have something to read | ||
147 | wait(who, "input") | ||
137 | -- try to read as much as possible | 148 | -- try to read as much as possible |
138 | local data, rec_err, partial = who:receive("*a") | 149 | local data, rec_err, partial = who:receive("*a") |
139 | -- if we had an error other than timeout, abort | 150 | -- if we had an error other than timeout, abort |
140 | if rec_err and rec_err ~= "timeout" then return error(rec_err) end | 151 | if rec_err and rec_err ~= "timeout" then return kick(who) end |
141 | -- if we got a timeout, we probably have partial results to send | 152 | -- if we got a timeout, we probably have partial results to send |
142 | data = data or partial | 153 | data = data or partial |
143 | -- renew our timestamp so scheduler sees we are active | ||
144 | context[who].last = socket.gettime() | ||
145 | -- forward what we got right away | 154 | -- forward what we got right away |
146 | local peer = context[who].peer | 155 | local peer = context[who].peer |
147 | while true do | 156 | while true do |
148 | -- tell scheduler we need to wait until we can send something | 157 | -- tell scheduler we need to wait until we can send something |
149 | coroutine.yield("sending") | 158 | wait(who, "output") |
150 | local ret, snd_err | 159 | local ret, snd_err |
151 | local start = 0 | 160 | local start = 0 |
152 | ret, snd_err, start = peer:send(data, start+1) | 161 | ret, snd_err, start = peer:send(data, start+1) |
153 | if ret then break | 162 | if ret then break |
154 | elseif snd_err ~= "timeout" then return error(snd_err) end | 163 | elseif snd_err ~= "timeout" then return kick(who) end |
155 | -- renew our timestamp so scheduler sees we are active | ||
156 | context[who].last = socket.gettime() | ||
157 | end | 164 | end |
158 | -- if we are done receiving, we are done with this side of the | 165 | -- if we are done receiving, we are done |
159 | -- connection | 166 | if not rec_err then return kick(who) end |
160 | if not rec_err then return nil end | ||
161 | -- otherwise tell schedule we have to wait for more data to arrive | ||
162 | coroutine.yield("receiving") | ||
163 | end | 167 | end |
164 | end | 168 | end |
165 | 169 | ||
166 | -- checks if a connection completed successfully and if it did, starts | ||
167 | -- forwarding all data | ||
168 | function chkcon(who) | ||
169 | local ret, err = who:connected() | ||
170 | if ret then | ||
171 | receiving:insert(context[who].peer) | ||
172 | context[who].last = socket.gettime() | ||
173 | coroutine.yield("receiving") | ||
174 | return forward(who) | ||
175 | else return error(err) end | ||
176 | end | ||
177 | |||
178 | -- loop waiting until something happens, restarting the thread to deal with | 170 | -- loop waiting until something happens, restarting the thread to deal with |
179 | -- what happened, and routing it to wait until something else happens | 171 | -- what happened, and routing it to wait until something else happens |
180 | function go() | 172 | function go() |
181 | while true do | 173 | while true do |
174 | print("will select for reading") | ||
175 | for i,v in ipairs(receiving) do | ||
176 | print(i, v) | ||
177 | end | ||
178 | print("will select for sending") | ||
179 | for i,v in ipairs(sending) do | ||
180 | print(i, v) | ||
181 | end | ||
182 | -- check which sockets are interesting and act on them | 182 | -- check which sockets are interesting and act on them |
183 | readable, writable = socket.select(receiving, sending, 3) | 183 | readable, writable = socket.select(receiving, sending, 3) |
184 | -- for all readable connections, resume its thread and route it | 184 | print("was readable") |
185 | for i,v in ipairs(readable) do | ||
186 | print(i, v) | ||
187 | end | ||
188 | print("was writable") | ||
189 | for i,v in ipairs(writable) do | ||
190 | print(i, v) | ||
191 | end | ||
192 | -- for all readable connections, resume its thread | ||
185 | for _, who in ipairs(readable) do | 193 | for _, who in ipairs(readable) do |
186 | receiving:remove(who) | 194 | receiving:remove(who) |
187 | if context[who] then | 195 | coroutine.resume(context[who].thread, who) |
188 | route(who, coroutine.resume(context[who].thread, who)) | ||
189 | end | ||
190 | end | 196 | end |
191 | -- for all writable connections, do the same | 197 | -- for all writable connections, do the same |
192 | for _, who in ipairs(writable) do | 198 | for _, who in ipairs(writable) do |
193 | sending:remove(who) | 199 | sending:remove(who) |
194 | if context[who] then | 200 | coroutine.resume(context[who].thread, who) |
195 | route(who, coroutine.resume(context[who].thread, who)) | ||
196 | end | ||
197 | end | 201 | end |
198 | -- put all inactive threads in death row | 202 | -- put all inactive threads in death row |
199 | local now = socket.gettime() | 203 | local now = socket.gettime() |
200 | local deathrow | 204 | local deathrow |
201 | for who, data in pairs(context) do | 205 | for who, data in pairs(context) do |
202 | if data.last then | 206 | if data.peer then |
203 | if now - data.last > TIMEOUT then | 207 | if now - data.last > TIMEOUT then |
204 | -- only create table if someone is doomed | 208 | -- only create table if at least one is doomed |
205 | deathrow = deathrow or {} | 209 | deathrow = deathrow or {} |
206 | deathrow[who] = true | 210 | deathrow[who] = true |
207 | end | 211 | end |