aboutsummaryrefslogtreecommitdiff
path: root/samples/forward.lua
blob: 548a7538fef4a21fbfe2b929e8f40f08f44e5497 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
-- load our favourite library
local socket = require"socket"

-- creates a new set data structure
-- The returned table keeps its members in the array part (1..n, no holes),
-- so it can be passed as-is to code that expects a plain list, such as
-- socket.select. Membership is tracked in a private reverse index, giving
-- constant-time insert and remove. Element order is NOT preserved by remove.
-- Methods (reached through the metatable):
--   set:insert(value)  adds value; no-op if it is already a member
--   set:remove(value)  removes value; no-op if it is not a member
function newset()
    -- maps each member to its current position in the array part
    local reverse = {}
    local set = {}
    return setmetatable(set, {__index = {
        insert = function(set, value)
            if not reverse[value] then
                table.insert(set, value)
                -- the length operator replaces table.getn, which no longer
                -- exists in Lua 5.2 and later
                reverse[value] = #set
            end
        end,
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                -- pop the last member and, unless it is the one being
                -- removed, move it into the vacated slot to avoid a hole
                local top = table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end

-- inactivity limit (compared against socket.gettime() differences) after
-- which a connection is kicked
local TIMEOUT = 10
-- set of sockets whose threads are waiting to receive data
local receiving = newset()
-- set of sockets whose threads are waiting to send data
local sending = newset()
-- per-socket state, keyed by the socket object itself:
--   servers:     thread, ohost, oport
--   connections: thread, peer, last (time of the most recent wait)
local context = {}

-- Parks the calling coroutine until the scheduler resumes it.
--   who:  socket the caller is waiting on; must have an entry in context
--   what: "input" queues who for readability, any other value queues it
--         for writability
function wait(who, what)
    local queue = sending
    if what == "input" then queue = receiving end
    queue:insert(who)
    -- record when the wait started; the scheduler uses it to spot idle peers
    context[who].last = socket.gettime()
    coroutine.yield()
end

-- initializes the forward server
-- Reads one tunnel specification per command-line argument from the global
-- `arg` table, each of the form <iport:ohost:oport>. For every tunnel a
-- listening socket is bound to iport on address "*", queued for readability
-- and given a context holding its accept thread and forwarding target.
-- Prints usage and exits with status 1 when no argument is given; raises an
-- error on a malformed specification or when the bind fails.
function init()
    -- '#' replaces table.getn, which no longer exists in Lua 5.2 and later
    if #arg < 1 then
        print("Usage")
        print("    lua forward.lua <iport:ohost:oport> ...")
        os.exit(1)
    end
    -- for each tunnel, start a new server socket
    for _, v in ipairs(arg) do
        -- capture forwarding parameters (skip drops the two match indices
        -- that string.find returns ahead of the captures)
        local iport, ohost, oport =
            socket.skip(2, string.find(v, "([^:]+):([^:]+):([^:]+)"))
        assert(iport, "invalid arguments")
        -- create our server socket
        local server = assert(socket.bind("*", iport))
        server:settimeout(0) -- we don't want to be killed by bad luck
        -- make sure server is tested for readability
        receiving:insert(server)
        -- add server context
        context[server] = {
            thread = coroutine.create(accept),
            ohost = ohost,
            oport = oport
        }
    end
end

-- starts a connection in a non-blocking way
--   who:        unconnected socket that has a context entry with a peer
--   host, port: address to connect to
-- When the first attempt reports "timeout" the thread waits until who is
-- writable and asks again; if that second attempt neither succeeds nor
-- reports "already connected", both who and its peer are kicked. In every
-- other case control passes to the forwarding loop.
-- NOTE(review): a first attempt that fails with anything other than
-- "timeout" also falls through to forward(who) -- confirm this is intended.
function connect(who, host, port)
    who:settimeout(0)
    local ok, why = who:connect(host, port)
    local pending = (not ok) and why == "timeout"
    if pending then
        -- connection is still in progress: sleep until the socket is writable
        wait(who, "output")
        ok, why = who:connect(host, port)
        local established = ok or why == "already connected"
        if not established then
            kick(context[who].peer)
            kick(who)
            return
        end
    end
    return forward(who)
end

-- gets rid of a client
-- Drops who from both wait sets, closes the socket and forgets its context.
-- A nil (or false) argument is ignored, so callers may pass a peer that
-- might not exist.
function kick(who)
    if not who then return end
    sending:remove(who)
    receiving:remove(who)
    who:close()
    context[who] = nil
end

-- loops accepting connections and creating new threads to deal with them
--   server: listening socket whose context holds ohost and oport
-- Runs forever as the body of the server's coroutine: each time it is
-- resumed it tries to accept one client, pairs it with a fresh outgoing
-- socket, starts a thread for each end, and then waits for the server to
-- become readable again.
function accept(server)
    while true do
        -- accept a new connection and start a new coroutine to deal with it
        local client = server:accept()
        if client then
            -- create contexts for client and peer.
            local peer, err = socket.tcp()
            if peer then
                -- both ends start their inactivity clock at the same moment
                local now = socket.gettime()
                context[client] = {
                    last = now,
                    -- client goes straight to forwarding loop
                    thread = coroutine.create(forward),
                    peer = peer,
                }
                context[peer] = {
                    last = now,
                    peer = client,
                    -- peer first tries to connect to forwarding address
                    thread = coroutine.create(connect),
                }
                -- resume peer and client so they can do their thing
                local ohost = context[server].ohost
                local oport = context[server].oport
                coroutine.resume(context[peer].thread, peer, ohost, oport)
                coroutine.resume(context[client].thread, client)
            else
                -- could not create the outgoing socket: drop the client
                print(err)
                client:close()
            end
        end
        -- tell scheduler we are done for now
        wait(server, "input")
    end
end

-- forwards all data arriving to the appropriate peer
--   who: connected socket with a context entry naming its peer
-- Runs as the body of who's coroutine. Each round waits for input, reads
-- whatever is available and relays it to the peer, resuming after a partial
-- send from the first byte that was not yet written. A receive or send
-- error other than "timeout" kicks who; a receive that completes without
-- error ends the session and kicks both ends.
function forward(who)
    who:settimeout(0)
    while true do
        -- wait until we have something to read
        wait(who, "input")
        -- try to read as much as possible
        local data, rec_err, partial = who:receive("*a")
        -- if we had an error other than timeout, abort
        if rec_err and rec_err ~= "timeout" then return kick(who) end
        -- if we got a timeout, we probably have partial results to send
        data = data or partial
        -- forward what we got right away
        local peer = context[who].peer
        -- index of the last byte of data already handed to the peer; it
        -- must live outside the retry loop so progress survives a timeout
        local sent = 0
        while true do
            -- tell scheduler we need to wait until we can send something
            wait(who, "output")
            local ret, snd_err, last = peer:send(data, sent + 1)
            if ret then break end
            if snd_err ~= "timeout" then return kick(who) end
            -- partial send: continue after the last byte that went out
            sent = last or sent
        end
        -- if we are done receiving, we are done
        if not rec_err then
            kick(who)
            kick(peer)
            break
        end
    end
end

-- loop waiting until something happens, restarting the thread to deal with
-- what happened, and routing it to wait until something else happens
-- Never returns. Each round resumes the thread of every ready socket and
-- then kicks every connection (any context that has a peer) whose last
-- wait started more than TIMEOUT ago; server contexts have no peer and are
-- never kicked.
function go()
    while true do
        -- check which sockets are interesting and act on them; the timeout
        -- makes select return even when nothing is ready, so the inactivity
        -- sweep below still runs while every socket is idle
        local readable, writable = socket.select(receiving, sending, TIMEOUT)
        -- for all readable connections, resume its thread
        for _, who in ipairs(readable) do
            -- a thread resumed earlier in this round may have kicked who
            if context[who] then
                receiving:remove(who)
                coroutine.resume(context[who].thread, who)
            end
        end
        -- for all writable connections, do the same
        for _, who in ipairs(writable) do
            if context[who] then
                sending:remove(who)
                coroutine.resume(context[who].thread, who)
            end
        end
        -- put all inactive threads in death row; they are collected first
        -- and kicked afterwards so context is not modified while traversed
        local now = socket.gettime()
        local deathrow
        for who, data in pairs(context) do
            if data.peer and now - data.last > TIMEOUT then
                -- only create table if at least one is doomed
                deathrow = deathrow or {}
                deathrow[who] = true
            end
        end
        -- finally kick everyone in deathrow
        if deathrow then
            for who in pairs(deathrow) do kick(who) end
        end
    end
end

-- script entry point: create one listening server per tunnel given on the
-- command line, then enter the scheduler loop (go loops forever)
init()
go()