From 773e35ced30fa2c03ddb2a332bf8a9aebb56aa44 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Tue, 23 Aug 2005 05:53:14 +0000 Subject: Compiled on Windows. Fixed a bunch of stuff. Almost ready to release. Implemented a nice dispatcher! Non-blocking check-links and forward server use the dispatcher. --- etc/check-links.lua | 31 +++--- etc/dispatch.lua | 296 +++++++++++++++++++++++++++++----------------------- etc/forward.lua | 65 ++++++++++++ 3 files changed, 242 insertions(+), 150 deletions(-) create mode 100644 etc/forward.lua (limited to 'etc') diff --git a/etc/check-links.lua b/etc/check-links.lua index e06cc91..725cd2a 100644 --- a/etc/check-links.lua +++ b/etc/check-links.lua @@ -5,33 +5,26 @@ -- Author: Diego Nehab -- RCS ID: $$ ----------------------------------------------------------------------------- -local dispatch, url, http, handler +local url = require("socket.url") +local dispatch = require("dispatch") +local http = require("socket.http") +dispatch.TIMEOUT = 10 +-- make sure the user knows how to invoke us arg = arg or {} if table.getn(arg) < 1 then print("Usage:\n luasocket check-links.lua [-n] {}") exit() end -if arg[1] ~= "-n" then - -- if using blocking I/O, simulate dispatcher interface - url = require("socket.url") - http = require("socket.http") - handler = { - start = function(self, f) - f() - end, - tcp = socket.tcp - } - http.TIMEOUT = 10 -else - -- if non-blocking I/O was requested, disable dispatcher +-- '-n' means we are running in non-blocking mode +if arg[1] == "-n" then + -- if non-blocking I/O was requested, use real dispatcher interface table.remove(arg, 1) - dispatch = require("dispatch") - dispatch.TIMEOUT = 10 - url = require("socket.url") - http = require("socket.http") - handler = dispatch.newhandler() + handler = dispatch.newhandler("coroutine") +else + -- if using blocking I/O, use fake dispatcher interface + handler = dispatch.newhandler("sequential") end local nthreads = 0 diff --git a/etc/dispatch.lua b/etc/dispatch.lua index e6c14a6..98fa8a8 100644 --- a/etc/dispatch.lua +++ b/etc/dispatch.lua @@ -11,23 +11,33 @@ module("dispatch") -- if too much time goes by without any activity in one of our sockets, we -- just kill it -TIMEOUT = 10 +TIMEOUT = 60 ----------------------------------------------------------------------------- --- Mega hack. Don't try to do this at home. +-- We implement 3 types of dispatchers: +-- sequential +-- coroutine +-- threaded +-- The user can choose whatever one is needed ----------------------------------------------------------------------------- --- Lua 5.1 has coroutine.running(). We need it here, so we use this terrible --- hack to emulate it in Lua itself --- This is very inefficient, but is very good for debugging. -local running -local resume = coroutine.resume -function coroutine.resume(co, ...) - running = co - return resume(co, unpack(arg)) +local handlert = {} + +-- default handler is coroutine +function newhandler(mode) + mode = mode or "coroutine" + return handlert[mode]() end -function coroutine.running() - return running +local function seqstart(self, func) + return func() +end + +-- sequential handler simply calls the functions and doesn't wrap I/O +function handlert.sequential() + return { + tcp = socket.tcp, + start = seqstart + } end ----------------------------------------------------------------------------- @@ -36,15 +46,11 @@ end -- we can't yield across calls to protect, so we rewrite it with coxpcall -- make sure you don't require any module that uses socket.protect before -- loading our hack -function socket.protect(f) - return f -end - function socket.protect(f) return function(...) local co = coroutine.create(f) while true do - local results = {resume(co, unpack(arg))} + local results = {coroutine.resume(co, unpack(arg))} local status = table.remove(results, 1) if not status then if type(results[1]) == 'table' then @@ -61,48 +67,68 @@ function socket.protect(f) end ----------------------------------------------------------------------------- --- socket.tcp() replacement for non-blocking I/O +-- Simple set data structure. O(1) everything. +----------------------------------------------------------------------------- +local function newset() + local reverse = {} + local set = {} + return setmetatable(set, {__index = { + insert = function(set, value) + if not reverse[value] then + table.insert(set, value) + reverse[value] = table.getn(set) + end + end, + remove = function(set, value) + local index = reverse[value] + if index then + reverse[value] = nil + local top = table.remove(set) + if top ~= value then + reverse[top] = index + set[index] = top + end + end + end + }}) +end + +----------------------------------------------------------------------------- +-- socket.tcp() wrapper for the coroutine dispatcher ----------------------------------------------------------------------------- -local function newtrap(dispatcher) - -- try to create underlying socket - local tcp, error = socket.tcp() +local function cowrap(dispatcher, tcp, error) if not tcp then return nil, error end -- put it in non-blocking mode right away tcp:settimeout(0) - -- metatable for trap produces new methods on demand for those that we + -- metatable for wrap produces new methods on demand for those that we -- don't override explicitly. local metat = { __index = function(table, key) table[key] = function(...) - return tcp[key](tcp, unpack(arg)) + arg[1] = tcp + return tcp[key](unpack(arg)) end + return table[key] end} - -- does user want to do his own non-blocking I/O? + -- does our user want to do his own non-blocking I/O? local zero = false - -- create a trap object that will behave just like a real socket object - local trap = { } + -- create a wrap object that will behave just like a real socket object + local wrap = { } -- we ignore settimeout to preserve our 0 timeout, but record whether -- the user wants to do his own non-blocking I/O - function trap:settimeout(mode, value) - if value == 0 then - zero = true - else - zero = false - end + function wrap:settimeout(value, mode) + if value == 0 then zero = true + else zero = false end return 1 end -- send in non-blocking mode and yield on timeout - function trap:send(data, first, last) + function wrap:send(data, first, last) first = (first or 1) - 1 local result, error while true do - -- tell dispatcher we want to keep sending before we yield - dispatcher.sending:insert(tcp) - -- mark time we started waiting - dispatcher.context[tcp].last = socket.gettime() - -- return control to dispatcher + -- return control to dispatcher and tell it we want to send -- if upon return the dispatcher tells us we timed out, -- return an error to whoever called us - if coroutine.yield() == "timeout" then + if coroutine.yield(dispatcher.sending, tcp) == "timeout" then return nil, "timeout" end -- try sending @@ -114,41 +140,35 @@ local function newtrap(dispatcher) end -- receive in non-blocking mode and yield on timeout -- or simply return partial read, if user requested timeout = 0 - function trap:receive(pattern, partial) + function wrap:receive(pattern, partial) local error = "timeout" local value while true do - -- tell dispatcher we want to keep receiving before we yield - dispatcher.receiving:insert(tcp) - -- mark time we started waiting - dispatcher.context[tcp].last = socket.gettime() - -- return control to dispatcher + -- return control to dispatcher and tell it we want to receive -- if upon return the dispatcher tells us we timed out, -- return an error to whoever called us - if coroutine.yield() == "timeout" then + if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then return nil, "timeout" end -- try receiving value, error, partial = tcp:receive(pattern, partial) -- if we are done, or there was an unexpected error, - -- break away from loop + -- break away from loop. also, if the user requested + -- zero timeout, return all we got if (error ~= "timeout") or zero then return value, error, partial end end end -- connect in non-blocking mode and yield on timeout - function trap:connect(host, port) + function wrap:connect(host, port) local result, error = tcp:connect(host, port) - -- mark time we started waiting - dispatcher.context[tcp].last = socket.gettime() if error == "timeout" then - -- tell dispatcher we will be able to write uppon connection - dispatcher.sending:insert(tcp) - -- return control to dispatcher + -- return control to dispatcher. we will be writable when + -- connection succeeds. -- if upon return the dispatcher tells us we have a -- timeout, just abort - if coroutine.yield() == "timeout" then + if coroutine.yield(dispatcher.sending, tcp) == "timeout" then return nil, "timeout" end -- when we come back, check if connection was successful @@ -158,110 +178,124 @@ local function newtrap(dispatcher) else return result, error end end -- accept in non-blocking mode and yield on timeout - function trap:accept() - local result, error = tcp:accept() - while error == "timeout" do - -- mark time we started waiting - dispatcher.context[tcp].last = socket.gettime() - -- tell dispatcher we will be able to read uppon connection - dispatcher.receiving:insert(tcp) - -- return control to dispatcher + function wrap:accept() + while 1 do + -- return control to dispatcher. we will be readable when a + -- connection arrives. -- if upon return the dispatcher tells us we have a -- timeout, just abort - if coroutine.yield() == "timeout" then + if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then return nil, "timeout" end + local client, error = tcp:accept() + if error ~= "timeout" then + return cowrap(dispatcher, client, error) + end end - return result, error end - -- remove thread from context - function trap:close() - dispatcher.context[tcp] = nil + -- remove cortn from context + function wrap:close() + dispatcher.stamp[tcp] = nil + dispatcher.sending.set:remove(tcp) + dispatcher.sending.cortn[tcp] = nil + dispatcher.receiving.set:remove(tcp) + dispatcher.receiving.cortn[tcp] = nil return tcp:close() end - -- add newly created socket to context - dispatcher.context[tcp] = { - thread = coroutine.running() - } - return setmetatable(trap, metat) + return setmetatable(wrap, metat) end + ----------------------------------------------------------------------------- --- Our set data structure +-- Our coroutine dispatcher ----------------------------------------------------------------------------- -local function newset() - local reverse = {} - local set = {} - return setmetatable(set, {__index = { - insert = function(set, value) - if not reverse[value] then - table.insert(set, value) - reverse[value] = table.getn(set) - end - end, - remove = function(set, value) - local index = reverse[value] - if index then - reverse[value] = nil - local top = table.remove(set) - if top ~= value then - reverse[top] = index - set[index] = top - end - end +local cometat = { __index = {} } + +function schedule(cortn, status, operation, tcp) + if status then + if cortn and operation then + operation.set:insert(tcp) + operation.cortn[tcp] = cortn + operation.stamp[tcp] = socket.gettime() end - }}) + else error(operation) end end ------------------------------------------------------------------------------ --- Our dispatcher API. ------------------------------------------------------------------------------ -local metat = { __index = {} } +function kick(operation, tcp) + operation.cortn[tcp] = nil + operation.set:remove(tcp) +end -function metat.__index:start(func) - local co = coroutine.create(func) - assert(coroutine.resume(co)) +function wakeup(operation, tcp) + local cortn = operation.cortn[tcp] + -- if cortn is still valid, wake it up + if cortn then + kick(operation, tcp) + return cortn, coroutine.resume(cortn) + -- othrewise, just get scheduler not to do anything + else + return nil, true + end end -function newhandler() - local dispatcher = { - context = {}, - sending = newset(), - receiving = newset() - } - function dispatcher.tcp() - return newtrap(dispatcher) +function abort(operation, tcp) + local cortn = operation.cortn[tcp] + if cortn then + kick(operation, tcp) + coroutine.resume(cortn, "timeout") end - return setmetatable(dispatcher, metat) end --- step through all active threads -function metat.__index:step() +-- step through all active cortns +function cometat.__index:step() -- check which sockets are interesting and act on them - local readable, writable = socket.select(self.receiving, - self.sending, 1) - -- for all readable connections, resume their threads - for _, who in ipairs(readable) do - if self.context[who] then - self.receiving:remove(who) - assert(coroutine.resume(self.context[who].thread)) - end + local readable, writable = socket.select(self.receiving.set, + self.sending.set, 1) + -- for all readable connections, resume their cortns and reschedule + -- when they yield back to us + for _, tcp in ipairs(readable) do + schedule(wakeup(self.receiving, tcp)) end -- for all writable connections, do the same - for _, who in ipairs(writable) do - if self.context[who] then - self.sending:remove(who) - assert(coroutine.resume(self.context[who].thread)) - end + for _, tcp in ipairs(writable) do + schedule(wakeup(self.sending, tcp)) end - -- politely ask replacement I/O functions in idle threads to + -- politely ask replacement I/O functions in idle cortns to -- return reporting a timeout local now = socket.gettime() - for who, data in pairs(self.context) do - if data.last and now - data.last > TIMEOUT then - self.sending:remove(who) - self.receiving:remove(who) - assert(coroutine.resume(self.context[who].thread, "timeout")) + for tcp, stamp in pairs(self.stamp) do + if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then + abort(self.sending, tcp) + abort(self.receiving, tcp) end end end + +function cometat.__index:start(func) + local cortn = coroutine.create(func) + schedule(cortn, coroutine.resume(cortn)) +end + +function handlert.coroutine() + local stamp = {} + local dispatcher = { + stamp = stamp, + sending = { + name = "sending", + set = newset(), + cortn = {}, + stamp = stamp + }, + receiving = { + name = "receiving", + set = newset(), + cortn = {}, + stamp = stamp + }, + } + function dispatcher.tcp() + return cowrap(dispatcher, socket.tcp()) + end + return setmetatable(dispatcher, cometat) +end + diff --git a/etc/forward.lua b/etc/forward.lua new file mode 100644 index 0000000..eac98ae --- /dev/null +++ b/etc/forward.lua @@ -0,0 +1,65 @@ +-- load our favourite library +local dispatch = require("dispatch") +local handler = dispatch.newhandler() + +-- make sure the user knows how to invoke us +if table.getn(arg) < 1 then + print("Usage") + print(" lua forward.lua ...") + os.exit(1) +end + +-- function to move data from one socket to the other +local function move(foo, bar) + local live + while 1 do + local data, error, partial = foo:receive(2048) + live = data or error == "timeout" + data = data or partial + local result, error = bar:send(data) + if not live or not result then + foo:close() + bar:close() + break + end + end +end + +-- for each tunnel, start a new server +for i, v in ipairs(arg) do + -- capture forwarding parameters + local _, _, iport, ohost, oport = string.find(v, "([^:]+):([^:]+):([^:]+)") + assert(iport, "invalid arguments") + -- create our server socket + local server = assert(handler.tcp()) + assert(server:setoption("reuseaddr", true)) + assert(server:bind("*", iport)) + assert(server:listen(32)) + -- handler for the server object loops accepting new connections + handler:start(function() + while 1 do + local client = assert(server:accept()) + assert(client:settimeout(0)) + -- for each new connection, start a new client handler + handler:start(function() + -- handler tries to connect to peer + local peer = assert(handler.tcp()) + assert(peer:settimeout(0)) + assert(peer:connect(ohost, oport)) + -- if sucessful, starts a new handler to send data from + -- client to peer + handler:start(function() + move(client, peer) + end) + -- afte starting new handler, enter in loop sending data from + -- peer to client + move(peer, client) + end) + end + end) +end + +-- simply loop stepping the server +while 1 do + handler:step() +end -- cgit v1.2.3-55-g6feb