From e57f9e9964ac16b1fd09028ea533457f3029d296 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Fri, 11 Mar 2005 00:20:21 +0000 Subject: Apparently, non-blocking connect doesn't work on windows if you use 0 timeout in the select call... --- samples/forward.lua | 138 +++++++++++++++++++++++++++------------------------- 1 file changed, 71 insertions(+), 67 deletions(-) (limited to 'samples') diff --git a/samples/forward.lua b/samples/forward.lua index c3f0605..46f51c8 100644 --- a/samples/forward.lua +++ b/samples/forward.lua @@ -35,6 +35,13 @@ local sending = newset() -- context for connections and servers local context = {} +function wait(who, what) + if what == "input" then receiving:insert(who) + else sending:insert(who) end + context[who].last = socket.gettime() + coroutine.yield() +end + -- initializes the forward server function init() if table.getn(arg) < 1 then @@ -63,145 +70,142 @@ function init() end -- starts a connection in a non-blocking way -function nbkcon(host, port) - local peer, err = socket.tcp() - if not peer then return nil, err end - peer:settimeout(0) - local ret, err = peer:connect(host, port) - if ret then return peer end - if err ~= "timeout" then - peer:close() - return nil, err +function connect(who, host, port) + who:settimeout(0.1) +print("trying to connect peer", who, host, port) + local ret, err = who:connect(host, port) + if not ret and err == "timeout" then +print("got timeout, will wait", who) + wait(who, "output") + ret, err = who:connected() +print("connection results arrived", who, ret, err) + end + if not ret then +print("connection failed", who) + kick(who) + kick(context[who].peer) + else + return forward(who) end - return peer end --- gets rid of a client +-- gets rid of a client and its peer function kick(who) - if context[who] then + if who and context[who] then sending:remove(who) receiving:remove(who) + local peer = context[who].peer context[who] = nil who:close() end end --- decides what to do with a thread based on coroutine return -function route(who, status, what) - if status and what then - if what == "receiving" then receiving:insert(who) end - if what == "sending" then sending:insert(who) end - else kick(who) end -end - -- loops accepting connections and creating new threads to deal with them function accept(server) while true do -- accept a new connection and start a new coroutine to deal with it local client = server:accept() +print("accepted ", client) if client then - -- start a new connection, non-blockingly, to the forwarding address - local ohost = context[server].ohost - local oport = context[server].oport - local peer = nbkcon(ohost, oport) + -- create contexts for client and peer. + local peer, err = socket.tcp() if peer then context[client] = { last = socket.gettime(), + -- client goes straight to forwarding loop thread = coroutine.create(forward), peer = peer, } - -- make sure peer will be tested for writing in the next select - -- round, which means the connection attempt has finished - sending:insert(peer) context[peer] = { + last = socket.gettime(), peer = client, - thread = coroutine.create(chkcon), + -- peer first tries to connect to forwarding address + thread = coroutine.create(connect), last = socket.gettime() } - -- put both in non-blocking mode - client:settimeout(0) - peer:settimeout(0) + -- resume peer and client so they can do their thing + local ohost = context[server].ohost + local oport = context[server].oport + coroutine.resume(context[peer].thread, peer, ohost, oport) + coroutine.resume(context[client].thread, client) else - -- otherwise just dump the client - client:close() + print(err) + client:close() end end -- tell scheduler we are done for now - coroutine.yield("receiving") + wait(server, "input") end end -- forwards all data arriving to the appropriate peer function forward(who) +print("starting to foward", who) + who:settimeout(0) while true do + -- wait until we have something to read + wait(who, "input") -- try to read as much as possible local data, rec_err, partial = who:receive("*a") -- if we had an error other than timeout, abort - if rec_err and rec_err ~= "timeout" then return error(rec_err) end + if rec_err and rec_err ~= "timeout" then return kick(who) end -- if we got a timeout, we probably have partial results to send data = data or partial - -- renew our timestamp so scheduler sees we are active - context[who].last = socket.gettime() -- forward what we got right away local peer = context[who].peer while true do -- tell scheduler we need to wait until we can send something - coroutine.yield("sending") + wait(who, "output") local ret, snd_err local start = 0 ret, snd_err, start = peer:send(data, start+1) if ret then break - elseif snd_err ~= "timeout" then return error(snd_err) end - -- renew our timestamp so scheduler sees we are active - context[who].last = socket.gettime() + elseif snd_err ~= "timeout" then return kick(who) end end - -- if we are done receiving, we are done with this side of the - -- connection - if not rec_err then return nil end - -- otherwise tell schedule we have to wait for more data to arrive - coroutine.yield("receiving") + -- if we are done receiving, we are done + if not rec_err then return kick(who) end end end --- checks if a connection completed successfully and if it did, starts --- forwarding all data -function chkcon(who) - local ret, err = who:connected() - if ret then - receiving:insert(context[who].peer) - context[who].last = socket.gettime() - coroutine.yield("receiving") - return forward(who) - else return error(err) end -end - -- loop waiting until something happens, restarting the thread to deal with -- what happened, and routing it to wait until something else happens function go() while true do +print("will select for reading") +for i,v in ipairs(receiving) do + print(i, v) +end +print("will select for sending") +for i,v in ipairs(sending) do + print(i, v) +end -- check which sockets are interesting and act on them readable, writable = socket.select(receiving, sending, 3) - -- for all readable connections, resume its thread and route it +print("was readable") +for i,v in ipairs(readable) do + print(i, v) +end +print("was writable") +for i,v in ipairs(writable) do + print(i, v) +end + -- for all readable connections, resume its thread for _, who in ipairs(readable) do receiving:remove(who) - if context[who] then - route(who, coroutine.resume(context[who].thread, who)) - end + coroutine.resume(context[who].thread, who) end -- for all writable connections, do the same for _, who in ipairs(writable) do sending:remove(who) - if context[who] then - route(who, coroutine.resume(context[who].thread, who)) - end + coroutine.resume(context[who].thread, who) end -- put all inactive threads in death row local now = socket.gettime() local deathrow for who, data in pairs(context) do - if data.last then + if data.peer then if now - data.last > TIMEOUT then - -- only create table if someone is doomed + -- only create table if at least one is doomed deathrow = deathrow or {} deathrow[who] = true end -- cgit v1.2.3-55-g6feb