1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
-----------------------------------------------------------------------------
-- TFTP support for the Lua language
-- LuaSocket toolkit.
-- Author: Diego Nehab
-- RCS ID: $Id$
-----------------------------------------------------------------------------
-----------------------------------------------------------------------------
-- Load required files
-----------------------------------------------------------------------------
local socket = require("socket")
local ltn12 = require("ltn12")
local url = require("url")
-----------------------------------------------------------------------------
-- Program constants
-----------------------------------------------------------------------------
-- short aliases for the byte/char conversions used to build and read packets
local char, byte = string.char, string.byte
-- default TFTP server port (assigned as a global, not a local)
PORT = 69
-- TFTP packet opcodes
local OP_RRQ, OP_WRQ, OP_DATA, OP_ACK, OP_ERROR = 1, 2, 3, 4, 5
-- opcode names, indexed by opcode value
local OP_INV = {"RRQ", "WRQ", "DATA", "ACK", "ERROR"}
-----------------------------------------------------------------------------
-- Packet creation functions
-----------------------------------------------------------------------------
-- Builds a read request (RRQ) packet: a 2-byte opcode, then the file name
-- and the transfer mode, each followed by a NUL byte.
local function RRQ(source, mode)
    local opcode = char(0, OP_RRQ)
    local nul = char(0)
    return opcode .. source .. nul .. mode .. nul
end
-- Builds a write request (WRQ) packet: a 2-byte opcode, then the file name
-- and the transfer mode, each followed by a NUL byte. Same layout as RRQ,
-- but the opcode must be OP_WRQ so the server treats it as a write.
local function WRQ(source, mode)
    return char(0, OP_WRQ) .. source .. char(0) .. mode .. char(0)
end
-- Builds an ACK packet for a block number, which is encoded as two bytes,
-- high byte first. The split uses math.floor rather than math.mod, because
-- math.mod does not exist in Lua 5.2 and later.
local function ACK(block)
    local high = math.floor(block / 256)
    local low = block - high * 256
    return char(0, OP_ACK, high, low)
end
-- Reads the opcode from the first two bytes of a packet (high byte first).
local function get_OP(dgram)
    local high = string.byte(dgram, 1)
    local low = string.byte(dgram, 2)
    return high * 256 + low
end
-----------------------------------------------------------------------------
-- Packet analysis functions
-----------------------------------------------------------------------------
-- Splits a DATA packet into its block number (bytes 3-4, high byte first)
-- and the payload that starts at byte 5.
local function split_DATA(dgram)
    local payload = string.sub(dgram, 5)
    local number = string.byte(dgram, 3) * 256 + string.byte(dgram, 4)
    return number, payload
end
-- Builds a readable message from an ERROR packet. The layout is: 2-byte
-- opcode, 2-byte error code (high byte first), then the message text
-- terminated by a NUL byte.
-- A missing terminator or a truncated header is tolerated, so producing the
-- error report never raises a second, unrelated Lua error.
local function get_ERROR(dgram)
    local high, low = string.byte(dgram, 3), string.byte(dgram, 4)
    if not high or not low then
        return "malformed error packet"
    end
    local msg = string.sub(dgram, 5)
    -- the message ends at the first NUL; keep all of it if the NUL is absent
    local nul = string.find(msg, "\000", 1, true)
    if nul then msg = string.sub(msg, 1, nul - 1) end
    return string.format("error code %d: %s", high * 256 + low, msg)
end
-----------------------------------------------------------------------------
-- The real work
-----------------------------------------------------------------------------
-- Downloads one file over TFTP in octet mode and feeds each new DATA
-- payload to a sink. Fields read from gett:
--   host: server name or IP address (resolved with socket.dns.toip)
--   port: server port; defaults to PORT when absent
--   path: remote file name placed in the read request
--   sink: optional ltn12 sink; received data is discarded when absent
-- Returns 1 on success. Failures are raised through socket.try, and the UDP
-- socket is closed before such an error propagates.
local function tget(gett)
    local retries, dgram, err, datahost, dataport, code
    local last = 0
    local con = socket.try(socket.udp())
    -- like socket.try, but releases the UDP socket before raising
    local function try(ok, msg)
        if not ok then con:close() end
        return socket.try(ok, msg)
    end
    -- convert from name to ip if needed; kept in a local so the caller's
    -- table is not modified
    local host = try(socket.dns.toip(gett.host))
    local port = gett.port or PORT
    con:settimeout(1)
    -- the first reply also gives the host/port used for the data transfer;
    -- resend the request on timeout, up to six attempts in total
    retries = 0
    repeat
        try(con:sendto(RRQ(gett.path, "octet"), host, port))
        dgram, datahost, dataport = con:receivefrom()
        retries = retries + 1
    until dgram or datahost ~= "timeout" or retries > 5
    try(dgram, datahost)
    -- associate socket with data host/port
    try(con:setpeername(datahost, dataport))
    -- default sink
    local sink = gett.sink or ltn12.sink.null()
    -- process all data packets
    while 1 do
        -- opcode plus block/error code: anything shorter cannot be decoded
        try(string.len(dgram) >= 4, "malformed packet")
        code = get_OP(dgram)
        -- decode the error text only for ERROR packets; running get_ERROR
        -- on a DATA payload is both wasteful and unsafe
        if code == OP_ERROR then try(false, get_ERROR(dgram)) end
        try(code == OP_DATA, "unhandled opcode " .. code)
        -- get data packet parts
        local block, data = split_DATA(dgram)
        -- NOTE(review): block numbers are 16 bits wide (see split_DATA);
        -- transfers that go past block 65535 are not handled here.
        -- if not repeated, write
        if block == last + 1 then
            try(sink(data))
            last = block
        end
        -- last packet brings less than 512 bytes of data
        if string.len(data) < 512 then
            try(con:send(ACK(block)))
            con:close()
            socket.try(sink(nil))
            return 1
        end
        -- acknowledge and get the next packet, resending the ACK on timeout
        retries = 0
        repeat
            try(con:send(ACK(last)))
            dgram, err = con:receive()
            retries = retries + 1
        until dgram or err ~= "timeout" or retries > 5
        try(dgram, err)
    end
end
-- URL fields assumed when the caller's URL does not specify them
local default = {
    scheme = "tftp",
    path = "/",
    port = PORT,
}
-- Parses a tftp:// URL into a request table. Raises (via socket.try) when
-- url.parse fails, when the scheme is not "tftp", or when no host is given.
local function parse(u)
    local parsed = socket.try(url.parse(u, default))
    local scheme = parsed.scheme
    socket.try(scheme == "tftp", "invalid scheme '" .. scheme .. "'")
    socket.try(parsed.host, "invalid host")
    return parsed
end
-- Downloads the file named by a tftp:// URL and returns its entire
-- contents as one string.
local function sget(u)
    local chunks = {}
    local request = parse(u)
    request.sink = ltn12.sink.table(chunks)
    tget(request)
    return table.concat(chunks)
end
-- Public entry point (assigned as a global). A string argument is treated
-- as a URL, and the downloaded contents are returned. Any other argument is
-- passed to tget as a request table. The function is wrapped with
-- socket.protect, which presumably turns errors raised through socket.try
-- into nil, message returns (LuaSocket convention; confirm for this version).
get = socket.protect(function(gett)
    if type(gett) ~= "string" then
        return tget(gett)
    end
    return sget(gett)
end)
|