From 8d4e240f6ae50d9b22ddc44f5e207018935da907 Mon Sep 17 00:00:00 2001
From: Diego Nehab
@@ -150,7 +150,7 @@ error.
@@ -206,7 +206,7 @@ Here are a few examples with the simple interface:
@@ -239,7 +239,7 @@ and
-- loads the FTP module and any libraries it requires
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- Log as user "anonymous" on server "ftp.tecgraf.puc-rio.br",
-- and get file "lua.tar.gz" from directory "pub/lua" as binary.
@@ -159,9 +159,9 @@ f, e = ftp.get("ftp://ftp.tecgraf.puc-rio.br/pub/lua/lua.tar.gz;type=i")
-- load needed modules
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
local ltn12 = require("ltn12")
-local url = require("url")
+local url = require("socket.url")
-- a function that returns a directory listing
function nlst(u)
@@ -230,7 +230,7 @@ message describing the reason for failure.
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- Log as user "fulano" on server "ftp.example.com",
-- using password "silva", and store a file "README" with contents
@@ -241,7 +241,7 @@ f, e = ftp.put("ftp://fulano:silva@ftp.example.com/README",
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
local ltn12 = require("ltn12")
-- Log as user "fulano" on server "ftp.example.com",
diff --git a/doc/http.html b/doc/http.html
index 4cbbe95..af58571 100644
--- a/doc/http.html
+++ b/doc/http.html
@@ -62,7 +62,7 @@ To obtain the http namespace, run:
-- loads the HTTP module and any libraries it requires
-local http = require("http")
+local http = require("socket.http")
-- load the http module
-http = require("http")
+http = require("socket.http")
-- connect to server "www.tecgraf.puc-rio.br" and retrieves this manual
-- file from "/luasocket/http.html"
@@ -231,7 +231,7 @@ And here is an example using the generic interface:
-- load the http module
-http = require("http")
+http = require("socket.http")
-- Requests information about a document, without downloading it.
-- Useful, for example, if you want to display a download gauge and need
@@ -276,7 +276,7 @@ authentication is required.
-- load required modules
-http = require("http")
+http = require("socket.http")
mime = require("mime")
-- Connect to server "www.example.com" and tries to retrieve
diff --git a/doc/introduction.html b/doc/introduction.html
index f8fe078..c88fa40 100644
--- a/doc/introduction.html
+++ b/doc/introduction.html
@@ -182,7 +182,7 @@ program.
-- load namespace
local socket = require("socket")
-- create a TCP socket and bind it to the local host, at any port
-local server = socket.try(socket.bind("*", 0))
+local server = assert(socket.bind("*", 0))
-- find out which port the OS chose for us
local ip, port = server:getsockname()
-- print a message informing what's up
@@ -287,13 +287,13 @@ local host, port = "localhost", 13
-- load namespace
local socket = require("socket")
-- convert host name to ip address
-local ip = socket.try(socket.dns.toip(host))
+local ip = assert(socket.dns.toip(host))
-- create a new UDP object
-local udp = socket.try(socket.udp())
+local udp = assert(socket.udp())
-- contact daytime host
-socket.try(udp:sendto("anything", ip, port))
+assert(udp:sendto("anything", ip, port))
-- retrieve the answer and print results
-io.write(socket.try((udp:receive())))
+io.write(assert(udp:receive()))
diff --git a/doc/ltn12.html b/doc/ltn12.html
index 44fcbe4..c5a0f59 100644
--- a/doc/ltn12.html
+++ b/doc/ltn12.html
@@ -271,7 +271,7 @@ The function returns the sink and the table used to store the chunks.
-- load needed modules
-local http = require("http")
+local http = require("socket.http")
local ltn12 = require("ltn12")
-- a simplified http.get function
diff --git a/doc/smtp.html b/doc/smtp.html
index 8feae3e..bd18bfa 100644
--- a/doc/smtp.html
+++ b/doc/smtp.html
@@ -69,7 +69,7 @@ To obtain the smtp namespace, run:
-- loads the SMTP module and everything it requires
-local smtp = require("smtp")
+local smtp = require("socket.smtp")
-- load the smtp support
-local smtp = require("smtp")
+local smtp = require("socket.smtp")
-- Connects to server "localhost" and sends a message to users
-- "fulano@example.com", "beltrano@example.com",
@@ -329,7 +329,7 @@ as listed in the introduction.
-- load the smtp support and its friends
-local smtp = require("smtp")
+local smtp = require("socket.smtp")
local mime = require("mime")
local ltn12 = require("ltn12")
diff --git a/doc/socket.html b/doc/socket.html
index f638fd9..18c71d1 100644
--- a/doc/socket.html
+++ b/doc/socket.html
@@ -145,7 +145,10 @@ socket.protect(func)
-Converts a function that throws exceptions into a safe function.
+Converts a function that throws exceptions into a safe function. This
+function only catches exceptions thrown by the try
+and newtry functions. It does not catch normal
+Lua errors.
@@ -346,7 +349,9 @@ socket.try(ret1 [, ret2 ... retN])
-Throws an exception in case of error.
+Throws an exception in case of error. The exception can only be caught
+by the protect function. It does not explode
+into an error message.
diff --git a/doc/url.html b/doc/url.html
index 56e1ef5..ac84d24 100644
--- a/doc/url.html
+++ b/doc/url.html
@@ -52,7 +52,7 @@ To obtain the url namespace, run:
-- loads the URL module
-local url = require("url")
+local url = require("socket.url")
@@ -193,7 +193,7 @@ The function returns the encoded string.
-- load url module
-url = require("url")
+url = require("socket.url")
code = url.escape("/#?;")
-- code = "%2f%23%3f%3b"
@@ -239,7 +239,7 @@ parsed_url = {
-- load url module
-url = require("url")
+url = require("socket.url")
parsed_url = url.parse("http://www.example.com/cgilua/index.lua?a=2#there")
-- parsed_url = {
diff --git a/src/buffer.c b/src/buffer.c
index 0ec7b4d..45cd0f2 100644
--- a/src/buffer.c
+++ b/src/buffer.c
@@ -123,7 +123,7 @@ int buf_meth_receive(lua_State *L, p_buf buf) {
else if (p[0] == '*' && p[1] == 'a') err = recvall(buf, &b);
else luaL_argcheck(L, 0, 2, "invalid receive pattern");
/* get a fixed number of bytes */
- } else err = recvraw(buf, (size_t) lua_tonumber(L, 2), &b);
+ } else err = recvraw(buf, (size_t) lua_tonumber(L, 2)-size, &b);
/* check if there was an error */
if (err != IO_DONE) {
/* we can't push anyting in the stack before pushing the
diff --git a/src/inet.c b/src/inet.c
index e2afcdf..d713643 100644
--- a/src/inet.c
+++ b/src/inet.c
@@ -220,7 +220,6 @@ const char *inet_tryconnect(p_sock ps, const char *address,
}
} else remote.sin_family = AF_UNSPEC;
err = sock_connect(ps, (SA *) &remote, sizeof(remote), tm);
- if (err != IO_DONE) sock_destroy(ps);
return sock_strerror(err);
}
diff --git a/src/luasocket.c b/src/luasocket.c
index 4b829f8..8f13dbc 100644
--- a/src/luasocket.c
+++ b/src/luasocket.c
@@ -87,7 +87,7 @@ static int global_unload(lua_State *L) {
static int base_open(lua_State *L) {
if (sock_open()) {
/* export functions (and leave namespace table on top of stack) */
- luaL_module(L, "socket", func, 0);
+ luaL_openlib(L, "socket", func, 0);
#ifdef LUASOCKET_DEBUG
lua_pushstring(L, "DEBUG");
lua_pushboolean(L, 1);
@@ -108,7 +108,7 @@ static int base_open(lua_State *L) {
/*-------------------------------------------------------------------------*\
* Initializes all library modules.
\*-------------------------------------------------------------------------*/
-LUASOCKET_API int luaopen_lsocket(lua_State *L) {
+LUASOCKET_API int luaopen_csocket(lua_State *L) {
int i;
base_open(L);
for (i = 0; mod[i].name; i++) mod[i].func(L);
diff --git a/src/luasocket.h b/src/luasocket.h
index db54a18..768e335 100644
--- a/src/luasocket.h
+++ b/src/luasocket.h
@@ -13,7 +13,7 @@
/*-------------------------------------------------------------------------*\
* Current luasocket version
\*-------------------------------------------------------------------------*/
-#define LUASOCKET_VERSION "LuaSocket 2.0 (beta3)"
+#define LUASOCKET_VERSION "LuaSocket 2.0"
#define LUASOCKET_COPYRIGHT "Copyright (C) 2004-2005 Diego Nehab"
#define LUASOCKET_AUTHORS "Diego Nehab"
@@ -27,6 +27,6 @@
/*-------------------------------------------------------------------------*\
* Initializes the library.
\*-------------------------------------------------------------------------*/
-LUASOCKET_API int luaopen_socket(lua_State *L);
+LUASOCKET_API int luaopen_csocket(lua_State *L);
#endif /* LUASOCKET_H */
diff --git a/src/mime.c b/src/mime.c
index dcc4af3..67f9f5b 100644
--- a/src/mime.c
+++ b/src/mime.c
@@ -78,9 +78,9 @@ static UC b64unbase[256];
/*-------------------------------------------------------------------------*\
* Initializes module
\*-------------------------------------------------------------------------*/
-MIME_API int luaopen_lmime(lua_State *L)
+MIME_API int luaopen_cmime(lua_State *L)
{
- luaL_module(L, "mime", func, 0);
+ luaL_openlib(L, "mime", func, 0);
/* initialize lookup tables */
qpsetup(qpclass, qpunbase);
b64setup(b64unbase);
diff --git a/src/mime.h b/src/mime.h
index 688d043..d596861 100644
--- a/src/mime.h
+++ b/src/mime.h
@@ -19,6 +19,6 @@
#define MIME_API extern
#endif
-MIME_API int luaopen_mime(lua_State *L);
+MIME_API int luaopen_cmime(lua_State *L);
#endif /* MIME_H */
diff --git a/src/mime.lua b/src/mime.lua
index 4d5bdba..6492a96 100644
--- a/src/mime.lua
+++ b/src/mime.lua
@@ -8,9 +8,10 @@
-----------------------------------------------------------------------------
-- Declare module and import dependencies
-----------------------------------------------------------------------------
+package.loaded.base = _G
local base = require("base")
local ltn12 = require("ltn12")
-local mime = require("lmime")
+local mime = require("cmime")
module("mime")
-- encode, decode and wrap algorithm tables
diff --git a/src/socket.h b/src/socket.h
index 368c2b6..639229d 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -45,11 +45,15 @@ int sock_sendto(p_sock ps, const char *data, size_t count,
size_t *sent, SA *addr, socklen_t addr_len, p_tm tm);
int sock_recvfrom(p_sock ps, char *data, size_t count,
size_t *got, SA *addr, socklen_t *addr_len, p_tm tm);
+
void sock_setnonblocking(p_sock ps);
void sock_setblocking(p_sock ps);
+
+int sock_waitfd(int fd, int sw, p_tm tm);
int sock_select(int n, fd_set *rfds, fd_set *wfds, fd_set *efds, p_tm tm);
int sock_connect(p_sock ps, SA *addr, socklen_t addr_len, p_tm tm);
+int sock_connected(p_sock ps, p_tm tm);
int sock_create(p_sock ps, int domain, int type, int protocol);
int sock_bind(p_sock ps, SA *addr, socklen_t addr_len);
int sock_listen(p_sock ps, int backlog);
diff --git a/src/socket.lua b/src/socket.lua
index 1c82750..f3563e7 100644
--- a/src/socket.lua
+++ b/src/socket.lua
@@ -7,10 +7,11 @@
-----------------------------------------------------------------------------
-- Declare module and import dependencies
-----------------------------------------------------------------------------
+package.loaded.base = _G
local base = require("base")
local string = require("string")
local math = require("math")
-local socket = require("lsocket")
+local socket = require("csocket")
module("socket")
-----------------------------------------------------------------------------
diff --git a/src/tcp.c b/src/tcp.c
index 0b3706b..3a84191 100644
--- a/src/tcp.c
+++ b/src/tcp.c
@@ -20,6 +20,7 @@
\*=========================================================================*/
static int global_create(lua_State *L);
static int meth_connect(lua_State *L);
+static int meth_connected(lua_State *L);
static int meth_listen(lua_State *L);
static int meth_bind(lua_State *L);
static int meth_send(lua_State *L);
@@ -45,6 +46,7 @@ static luaL_reg tcp[] = {
{"bind", meth_bind},
{"close", meth_close},
{"connect", meth_connect},
+ {"connected", meth_connected},
{"dirty", meth_dirty},
{"getfd", meth_getfd},
{"getpeername", meth_getpeername},
@@ -113,12 +115,12 @@ static int meth_receive(lua_State *L) {
}
static int meth_getstats(lua_State *L) {
- p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
+ p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_getstats(L, &tcp->buf);
}
static int meth_setstats(lua_State *L) {
- p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1);
+ p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1);
return buf_meth_setstats(L, &tcp->buf);
}
@@ -224,6 +226,22 @@ static int meth_connect(lua_State *L)
return 1;
}
+static int meth_connected(lua_State *L)
+{
+ static t_tm tm = {-1, -1};
+ p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1);
+ int err = sock_connected(&tcp->sock, &tm);
+ if (err != IO_DONE) {
+ lua_pushnil(L);
+ lua_pushstring(L, sock_strerror(err));
+ return 2;
+ }
+ /* turn master object into a client object */
+ aux_setclass(L, "tcp{client}", 1);
+ lua_pushnumber(L, 1);
+ return 1;
+}
+
/*-------------------------------------------------------------------------*\
* Closes socket used by object
\*-------------------------------------------------------------------------*/
diff --git a/src/unix.c b/src/unix.c
index 1e0e252..c169268 100644
--- a/src/unix.c
+++ b/src/unix.c
@@ -32,6 +32,8 @@ static int meth_settimeout(lua_State *L);
static int meth_getfd(lua_State *L);
static int meth_setfd(lua_State *L);
static int meth_dirty(lua_State *L);
+static int meth_getstats(lua_State *L);
+static int meth_setstats(lua_State *L);
static const char *unix_tryconnect(p_unix un, const char *path);
static const char *unix_trybind(p_unix un, const char *path);
@@ -46,6 +48,8 @@ static luaL_reg un[] = {
{"connect", meth_connect},
{"dirty", meth_dirty},
{"getfd", meth_getfd},
+ {"getstats", meth_getstats},
+ {"setstats", meth_setstats},
{"listen", meth_listen},
{"receive", meth_receive},
{"send", meth_send},
@@ -75,7 +79,7 @@ static luaL_reg func[] = {
/*-------------------------------------------------------------------------*\
* Initializes module
\*-------------------------------------------------------------------------*/
-int unix_open(lua_State *L) {
+int luaopen_socketunix(lua_State *L) {
/* create classes */
aux_newclass(L, "unix{master}", un);
aux_newclass(L, "unix{client}", un);
@@ -84,11 +88,9 @@ int unix_open(lua_State *L) {
aux_add2group(L, "unix{master}", "unix{any}");
aux_add2group(L, "unix{client}", "unix{any}");
aux_add2group(L, "unix{server}", "unix{any}");
- aux_add2group(L, "unix{client}", "unix{client,server}");
- aux_add2group(L, "unix{server}", "unix{client,server}");
/* define library functions */
- luaL_openlib(L, NULL, func, 0);
- return 0;
+ luaL_openlib(L, "socket", func, 0);
+ return 1;
}
/*=========================================================================*\
@@ -107,6 +109,16 @@ static int meth_receive(lua_State *L) {
return buf_meth_receive(L, &un->buf);
}
+static int meth_getstats(lua_State *L) {
+ p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1);
+ return buf_meth_getstats(L, &un->buf);
+}
+
+static int meth_setstats(lua_State *L) {
+ p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1);
+ return buf_meth_setstats(L, &un->buf);
+}
+
/*-------------------------------------------------------------------------*\
* Just call option handler
\*-------------------------------------------------------------------------*/
@@ -250,7 +262,8 @@ static int meth_close(lua_State *L)
{
p_unix un = (p_unix) aux_checkgroup(L, "unix{any}", 1);
sock_destroy(&un->sock);
- return 0;
+ lua_pushnumber(L, 1);
+ return 1;
}
/*-------------------------------------------------------------------------*\
@@ -277,7 +290,7 @@ static int meth_listen(lua_State *L)
\*-------------------------------------------------------------------------*/
static int meth_shutdown(lua_State *L)
{
- p_unix un = (p_unix) aux_checkgroup(L, "unix{client}", 1);
+ p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1);
const char *how = luaL_optstring(L, 2, "both");
switch (how[0]) {
case 'b':
diff --git a/src/unix.h b/src/unix.h
index 7b2a5c5..aaaef3d 100644
--- a/src/unix.h
+++ b/src/unix.h
@@ -23,6 +23,6 @@ typedef struct t_unix_ {
} t_unix;
typedef t_unix *p_unix;
-int unix_open(lua_State *L);
+int luaopen_socketunix(lua_State *L);
#endif /* UNIX_H */
diff --git a/src/usocket.c b/src/usocket.c
index c1ab725..3428a0c 100644
--- a/src/usocket.c
+++ b/src/usocket.c
@@ -22,7 +22,7 @@
#define WAITFD_R POLLIN
#define WAITFD_W POLLOUT
#define WAITFD_C (POLLIN|POLLOUT)
-static int sock_waitfd(int fd, int sw, p_tm tm) {
+int sock_waitfd(int fd, int sw, p_tm tm) {
int ret;
struct pollfd pfd;
pfd.fd = fd;
@@ -44,7 +44,7 @@ static int sock_waitfd(int fd, int sw, p_tm tm) {
#define WAITFD_W 2
#define WAITFD_C (WAITFD_R|WAITFD_W)
-static int sock_waitfd(int fd, int sw, p_tm tm) {
+int sock_waitfd(int fd, int sw, p_tm tm) {
int ret;
fd_set rfds, wfds, *rp, *wp;
struct timeval tv, *tp;
@@ -166,12 +166,20 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) {
while ((err = errno) == EINTR);
/* if connection failed immediately, return error code */
if (err != EINPROGRESS && err != EAGAIN) return err;
+ /* zero timeout case optimization */
+ if (tm_iszero(tm)) return IO_TIMEOUT;
/* wait until we have the result of the connection attempt or timeout */
- if ((err = sock_waitfd(*ps, WAITFD_C, tm)) == IO_CLOSED) {
- /* finaly find out if we succeeded connecting */
+ return sock_connected(ps, tm);
+}
+
+/*-------------------------------------------------------------------------*\
+* Checks if socket is connected, or return reason for failure
+\*-------------------------------------------------------------------------*/
+int sock_connected(p_sock ps, p_tm tm) {
+ int err;
+ if ((err = sock_waitfd(*ps, WAITFD_C, tm) == IO_CLOSED)) {
if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE;
else return errno;
- /* timed out or some weirder error */
} else return err;
}
@@ -321,13 +329,17 @@ void sock_setnonblocking(p_sock ps) {
int sock_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) {
*hp = gethostbyaddr(addr, len, AF_INET);
if (*hp) return IO_DONE;
- else return h_errno;
+ else if (h_errno) return h_errno;
+ else if (errno) return errno;
+ else return IO_UNKNOWN;
}
int sock_gethostbyname(const char *addr, struct hostent **hp) {
*hp = gethostbyname(addr);
if (*hp) return IO_DONE;
- else return h_errno;
+ else if (h_errno) return h_errno;
+ else if (errno) return errno;
+ else return IO_UNKNOWN;
}
/*-------------------------------------------------------------------------*\
diff --git a/src/wsocket.c b/src/wsocket.c
index 69fac4d..c0686cd 100644
--- a/src/wsocket.c
+++ b/src/wsocket.c
@@ -45,7 +45,7 @@ int sock_close(void) {
#define WAITFD_E 4
#define WAITFD_C (WAITFD_E|WAITFD_W)
-static int sock_waitfd(t_sock fd, int sw, p_tm tm) {
+int sock_waitfd(t_sock fd, int sw, p_tm tm) {
int ret;
fd_set rfds, wfds, efds, *rp = NULL, *wp = NULL, *ep = NULL;
struct timeval tv, *tp = NULL;
@@ -118,7 +118,17 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) {
/* make sure the system is trying to connect */
err = WSAGetLastError();
if (err != WSAEWOULDBLOCK && err != WSAEINPROGRESS) return err;
+ /* zero timeout case optimization */
+ if (tm_iszero(tm)) return IO_TIMEOUT;
/* we wait until something happens */
+ return sock_connected(ps, tm);
+}
+
+/*-------------------------------------------------------------------------*\
+* Check if socket is connected
+\*-------------------------------------------------------------------------*/
+int sock_connected(p_sock ps) {
+ int err;
if ((err = sock_waitfd(*ps, WAITFD_C, tm)) == IO_CLOSED) {
int len = sizeof(err);
/* give windows time to set the error (yes, disgusting) */
@@ -126,9 +136,8 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) {
/* find out why we failed */
getsockopt(*ps, SOL_SOCKET, SO_ERROR, (char *)&err, &len);
/* we KNOW there was an error. if why is 0, we will return
- * "unknown error", but it's not really our fault */
+ * "unknown error", but it's not really our fault */
return err > 0? err: IO_UNKNOWN;
- /* here we deal with the case in which it worked, timedout or weird errors */
} else return err;
}
diff --git a/test/dicttest.lua b/test/dicttest.lua
index a37ec8d..7ac7811 100644
--- a/test/dicttest.lua
+++ b/test/dicttest.lua
@@ -1,3 +1,3 @@
local dict = require"socket.dict"
-for i,v in dict.get("dict://dell-diego/d:banana") do print(v) end
+for i,v in dict.get("dict://localhost/d:teste") do print(v) end
diff --git a/test/httptest.lua b/test/httptest.lua
index 8862ceb..2335fcb 100644
--- a/test/httptest.lua
+++ b/test/httptest.lua
@@ -23,7 +23,7 @@ http.TIMEOUT = 10
local t = socket.gettime()
host = host or "diego.student.princeton.edu"
-proxy = proxy or "http://dell-diego:3128"
+proxy = proxy or "http://localhost:3128"
prefix = prefix or "/luasocket-test"
cgiprefix = cgiprefix or "/luasocket-test-cgi"
index_file = "test/index.html"
diff --git a/test/testclnt.lua b/test/testclnt.lua
index c2c782c..e3f2b94 100644
--- a/test/testclnt.lua
+++ b/test/testclnt.lua
@@ -465,16 +465,14 @@ print("Testing " .. 2*size .. " bytes")
remote(string.format([[
data:send(string.rep("a", %d))
socket.sleep(0.5)
- data:send(string.rep("b", %d))
+ data:send(string.rep("b", %d) .. "\n")
]], size, size))
local err = "timeout"
local part = ""
local str
data:settimeout(0)
while 1 do
- local needed = 2*size - string.len(part)
- assert(needed > 0, "weird")
- str, err, part = data:receive(needed, part)
+ str, err, part = data:receive("*l", part)
if err ~= "timeout" then break end
end
assert(str == (string.rep("a", size) .. string.rep("b", size)))
@@ -482,15 +480,14 @@ remote(string.format([[
remote(string.format([[
str = data:receive(%d)
socket.sleep(0.5)
- str = data:receive(%d, str)
+ str = data:receive(2*%d, str)
data:send(str)
]], size, size))
data:settimeout(0)
- local sofar = 1
+ local start = 0
while 1 do
- _, err, part = data:send(str, sofar)
+ ret, err, start = data:send(str, start+1)
if err ~= "timeout" then break end
- sofar = sofar + part
end
data:send("\n")
data:settimeout(-1)
@@ -501,6 +498,7 @@ end
------------------------------------------------------------------------
+
test("method registration")
test_methods(socket.tcp(), {
"accept",
@@ -622,7 +620,7 @@ test_nonblocking(17)
test_nonblocking(200)
test_nonblocking(4091)
test_nonblocking(80199)
-test_nonblocking(8000000)
+test_nonblocking(800000)
test_nonblocking(80199)
test_nonblocking(4091)
test_nonblocking(200)
diff --git a/test/testsrvr.lua b/test/testsrvr.lua
index 2408e83..f1972c2 100644
--- a/test/testsrvr.lua
+++ b/test/testsrvr.lua
@@ -9,6 +9,7 @@ while 1 do
while 1 do
command = assert(control:receive());
assert(control:send(ack));
+ print(command);
(loadstring(command))();
end
end
diff --git a/test/utestclnt.lua b/test/utestclnt.lua
new file mode 100644
index 0000000..f002c6e
--- /dev/null
+++ b/test/utestclnt.lua
@@ -0,0 +1,644 @@
+require"socket"
+local socket = require"socket.unix"
+
+host = "luasocket"
+
+function pass(...)
+ local s = string.format(unpack(arg))
+ io.stderr:write(s, "\n")
+end
+
+function fail(...)
+ local s = string.format(unpack(arg))
+ io.stderr:write("ERROR: ", s, "!\n")
+socket.sleep(3)
+ os.exit()
+end
+
+function warn(...)
+ local s = string.format(unpack(arg))
+ io.stderr:write("WARNING: ", s, "\n")
+end
+
+function remote(...)
+ local s = string.format(unpack(arg))
+ s = string.gsub(s, "\n", ";")
+ s = string.gsub(s, "%s+", " ")
+ s = string.gsub(s, "^%s*", "")
+ control:send(s .. "\n")
+ control:receive()
+end
+
+function test(test)
+ io.stderr:write("----------------------------------------------\n",
+ "testing: ", test, "\n",
+ "----------------------------------------------\n")
+end
+
+function uconnect(path)
+ local u = assert(socket.unix())
+ assert(u:connect(path))
+ return u
+end
+
+function ubind(path)
+ local u = assert(socket.unix())
+ assert(u:bind(path))
+ assert(u:listen(5))
+ return u
+end
+
+function check_timeout(tm, sl, elapsed, err, opp, mode, alldone)
+ if tm < sl then
+ if opp == "send" then
+ if not err then warn("must be buffered")
+ elseif err == "timeout" then pass("proper timeout")
+ else fail("unexpected error '%s'", err) end
+ else
+ if err ~= "timeout" then fail("should have timed out")
+ else pass("proper timeout") end
+ end
+ else
+ if mode == "total" then
+ if elapsed > tm then
+ if err ~= "timeout" then fail("should have timed out")
+ else pass("proper timeout") end
+ elseif elapsed < tm then
+ if err then fail(err)
+ else pass("ok") end
+ else
+ if alldone then
+ if err then fail("unexpected error '%s'", err)
+ else pass("ok") end
+ else
+ if err ~= "timeout" then fail(err)
+ else pass("proper timeoutk") end
+ end
+ end
+ else
+ if err then fail(err)
+ else pass("ok") end
+ end
+ end
+end
+
+if not socket.DEBUG then
+ fail("Please define LUASOCKET_DEBUG and recompile LuaSocket")
+end
+
+io.stderr:write("----------------------------------------------\n",
+"LuaSocket Test Procedures\n",
+"----------------------------------------------\n")
+
+start = socket.gettime()
+
+function reconnect()
+ io.stderr:write("attempting data connection... ")
+ if data then data:close() end
+ remote [[
+ i = i or 1
+ if data then data:close() data = nil end
+ print("accepting")
+ data = server:accept()
+ i = i + 1
+ print("done " .. i)
+ ]]
+ data, err = uconnect(host, port)
+ if not data then fail(err)
+ else pass("connected!") end
+end
+
+pass("attempting control connection...")
+control, err = uconnect(host, port)
+if err then fail(err)
+else pass("connected!") end
+
+------------------------------------------------------------------------
+function test_methods(sock, methods)
+ for _, v in methods do
+ if type(sock[v]) ~= "function" then
+ fail(sock.class .. " method '" .. v .. "' not registered")
+ end
+ end
+ pass(sock.class .. " methods are ok")
+end
+
+------------------------------------------------------------------------
+function test_mixed(len)
+ reconnect()
+ local inter = math.ceil(len/4)
+ local p1 = "unix " .. string.rep("x", inter) .. "line\n"
+ local p2 = "dos " .. string.rep("y", inter) .. "line\r\n"
+ local p3 = "raw " .. string.rep("z", inter) .. "bytes"
+ local p4 = "end" .. string.rep("w", inter) .. "bytes"
+ local bp1, bp2, bp3, bp4
+remote (string.format("str = data:receive(%d)",
+ string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4)))
+ sent, err = data:send(p1..p2..p3..p4)
+ if err then fail(err) end
+remote "data:send(str); data:close()"
+ bp1, err = data:receive()
+ if err then fail(err) end
+ bp2, err = data:receive()
+ if err then fail(err) end
+ bp3, err = data:receive(string.len(p3))
+ if err then fail(err) end
+ bp4, err = data:receive("*a")
+ if err then fail(err) end
+ if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then
+ pass("patterns match")
+ else fail("patterns don't match") end
+end
+
+------------------------------------------------------------------------
+function test_asciiline(len)
+ reconnect()
+ local str, str10, back, err
+ str = string.rep("x", math.mod(len, 10))
+ str10 = string.rep("aZb.c#dAe?", math.floor(len/10))
+ str = str .. str10
+remote "str = data:receive()"
+ sent, err = data:send(str.."\n")
+ if err then fail(err) end
+remote "data:send(str ..'\\n')"
+ back, err = data:receive()
+ if err then fail(err) end
+ if back == str then pass("lines match")
+ else fail("lines don't match") end
+end
+
+------------------------------------------------------------------------
+function test_rawline(len)
+ reconnect()
+ local str, str10, back, err
+ str = string.rep(string.char(47), math.mod(len, 10))
+ str10 = string.rep(string.char(120,21,77,4,5,0,7,36,44,100),
+ math.floor(len/10))
+ str = str .. str10
+remote "str = data:receive()"
+ sent, err = data:send(str.."\n")
+ if err then fail(err) end
+remote "data:send(str..'\\n')"
+ back, err = data:receive()
+ if err then fail(err) end
+ if back == str then pass("lines match")
+ else fail("lines don't match") end
+end
+
+------------------------------------------------------------------------
+function test_raw(len)
+ reconnect()
+ local half = math.floor(len/2)
+ local s1, s2, back, err
+ s1 = string.rep("x", half)
+ s2 = string.rep("y", len-half)
+remote (string.format("str = data:receive(%d)", len))
+ sent, err = data:send(s1)
+ if err then fail(err) end
+ sent, err = data:send(s2)
+ if err then fail(err) end
+remote "data:send(str)"
+ back, err = data:receive(len)
+ if err then fail(err) end
+ if back == s1..s2 then pass("blocks match")
+ else fail("blocks don't match") end
+end
+
+------------------------------------------------------------------------
+function test_totaltimeoutreceive(len, tm, sl)
+ reconnect()
+ local str, err, partial
+ pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
+ remote (string.format ([[
+ data:settimeout(%d)
+ str = string.rep('a', %d)
+ data:send(str)
+ print('server: sleeping for %ds')
+ socket.sleep(%d)
+ print('server: woke up')
+ data:send(str)
+ ]], 2*tm, len, sl, sl))
+ data:settimeout(tm, "total")
+local t = socket.gettime()
+ str, err, partial, elapsed = data:receive(2*len)
+ check_timeout(tm, sl, elapsed, err, "receive", "total",
+ string.len(str or partial) == 2*len)
+end
+
+------------------------------------------------------------------------
+function test_totaltimeoutsend(len, tm, sl)
+ reconnect()
+ local str, err, total
+ pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
+ remote (string.format ([[
+ data:settimeout(%d)
+ str = data:receive(%d)
+ print('server: sleeping for %ds')
+ socket.sleep(%d)
+ print('server: woke up')
+ str = data:receive(%d)
+ ]], 2*tm, len, sl, sl, len))
+ data:settimeout(tm, "total")
+ str = string.rep("a", 2*len)
+ total, err, partial, elapsed = data:send(str)
+ check_timeout(tm, sl, elapsed, err, "send", "total",
+ total == 2*len)
+end
+
+------------------------------------------------------------------------
+function test_blockingtimeoutreceive(len, tm, sl)
+ reconnect()
+ local str, err, partial
+ pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
+ remote (string.format ([[
+ data:settimeout(%d)
+ str = string.rep('a', %d)
+ data:send(str)
+ print('server: sleeping for %ds')
+ socket.sleep(%d)
+ print('server: woke up')
+ data:send(str)
+ ]], 2*tm, len, sl, sl))
+ data:settimeout(tm)
+ str, err, partial, elapsed = data:receive(2*len)
+ check_timeout(tm, sl, elapsed, err, "receive", "blocking",
+ string.len(str or partial) == 2*len)
+end
+
+------------------------------------------------------------------------
+function test_blockingtimeoutsend(len, tm, sl)
+ reconnect()
+ local str, err, total
+ pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
+ remote (string.format ([[
+ data:settimeout(%d)
+ str = data:receive(%d)
+ print('server: sleeping for %ds')
+ socket.sleep(%d)
+ print('server: woke up')
+ str = data:receive(%d)
+ ]], 2*tm, len, sl, sl, len))
+ data:settimeout(tm)
+ str = string.rep("a", 2*len)
+ total, err, partial, elapsed = data:send(str)
+ check_timeout(tm, sl, elapsed, err, "send", "blocking",
+ total == 2*len)
+end
+
+------------------------------------------------------------------------
+function empty_connect()
+ reconnect()
+ if data then data:close() data = nil end
+ remote [[
+ if data then data:close() data = nil end
+ data = server:accept()
+ ]]
+ data, err = socket.connect("", port)
+ if not data then
+ pass("ok")
+ data = socket.connect(host, port)
+ else
+ pass("gethostbyname returns localhost on empty string...")
+ end
+end
+
+------------------------------------------------------------------------
+function isclosed(c)
+ return c:getfd() == -1 or c:getfd() == (2^32-1)
+end
+
+function active_close()
+ reconnect()
+ if isclosed(data) then fail("should not be closed") end
+ data:close()
+ if not isclosed(data) then fail("should be closed") end
+ data = nil
+ local udp = socket.udp()
+ if isclosed(udp) then fail("should not be closed") end
+ udp:close()
+ if not isclosed(udp) then fail("should be closed") end
+ pass("ok")
+end
+
+------------------------------------------------------------------------
+function test_closed()
+ local back, partial, err
+ local str = 'little string'
+ reconnect()
+ pass("trying read detection")
+ remote (string.format ([[
+ data:send('%s')
+ data:close()
+ data = nil
+ ]], str))
+ -- try to get a line
+ back, err, partial = data:receive()
+ if not err then fail("should have gotten 'closed'.")
+ elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.")
+ elseif str ~= partial then fail("didn't receive partial result.")
+ else pass("graceful 'closed' received") end
+ reconnect()
+ pass("trying write detection")
+ remote [[
+ data:close()
+ data = nil
+ ]]
+ total, err, partial = data:send(string.rep("ugauga", 100000))
+ if not err then
+ pass("failed: output buffer is at least %d bytes long!", total)
+ elseif err ~= "closed" then
+ fail("got '"..err.."' instead of 'closed'.")
+ else
+ pass("graceful 'closed' received after %d bytes were sent", partial)
+ end
+end
+
+------------------------------------------------------------------------
+function test_selectbugs()
+ local r, s, e = socket.select(nil, nil, 0.1)
+ assert(type(r) == "table" and type(s) == "table" and
+ (e == "timeout" or e == "error"))
+ pass("both nil: ok")
+ local udp = socket.udp()
+ udp:close()
+ r, s, e = socket.select({ udp }, { udp }, 0.1)
+ assert(type(r) == "table" and type(s) == "table" and
+ (e == "timeout" or e == "error"))
+ pass("closed sockets: ok")
+ e = pcall(socket.select, "wrong", 1, 0.1)
+ assert(e == false)
+ e = pcall(socket.select, {}, 1, 0.1)
+ assert(e == false)
+ pass("invalid input: ok")
+end
+
+------------------------------------------------------------------------
+function accept_timeout()
+ io.stderr:write("accept with timeout (if it hangs, it failed): ")
+ local s, e = socket.bind("*", 0, 0)
+ assert(s, e)
+ local t = socket.gettime()
+ s:settimeout(1)
+ local c, e = s:accept()
+ assert(not c, "should not accept")
+ assert(e == "timeout", string.format("wrong error message (%s)", e))
+ t = socket.gettime() - t
+ assert(t < 2, string.format("took to long to give up (%gs)", t))
+ s:close()
+ pass("good")
+end
+
+------------------------------------------------------------------------
+function connect_timeout()
+ io.stderr:write("connect with timeout (if it hangs, it failed!): ")
+ local t = socket.gettime()
+ local c, e = socket.tcp()
+ assert(c, e)
+ c:settimeout(0.1)
+ local t = socket.gettime()
+ local r, e = c:connect("127.0.0.2", 80)
+ assert(not r, "should not connect")
+ assert(socket.gettime() - t < 2, "took too long to give up.")
+ c:close()
+ print("ok")
+end
+
+------------------------------------------------------------------------
+function accept_errors()
+ io.stderr:write("not listening: ")
+ local d, e = socket.bind("*", 0)
+ assert(d, e);
+ local c, e = socket.tcp();
+ assert(c, e);
+ d:setfd(c:getfd())
+ d:settimeout(2)
+ local r, e = d:accept()
+ assert(not r and e)
+ print("ok: ", e)
+ io.stderr:write("not supported: ")
+ local c, e = socket.udp()
+ assert(c, e);
+ d:setfd(c:getfd())
+ local r, e = d:accept()
+ assert(not r and e)
+ print("ok: ", e)
+end
+
+------------------------------------------------------------------------
+function connect_errors()
+ io.stderr:write("connection refused: ")
+ local c, e = socket.connect("localhost", 1);
+ assert(not c and e)
+ print("ok: ", e)
+ io.stderr:write("host not found: ")
+ local c, e = socket.connect("host.is.invalid", 1);
+ assert(not c and e, e)
+ print("ok: ", e)
+end
+
+------------------------------------------------------------------------
+function rebind_test()
+ local c = socket.bind("localhost", 0)
+ local i, p = c:getsockname()
+ local s, e = socket.tcp()
+ assert(s, e)
+ s:setoption("reuseaddr", false)
+ r, e = s:bind("localhost", p)
+ assert(not r, "managed to rebind!")
+ assert(e)
+ print("ok: ", e)
+end
+
+------------------------------------------------------------------------
+function getstats_test()
+ reconnect()
+ local t = 0
+ for i = 1, 25 do
+ local c = math.random(1, 100)
+ remote (string.format ([[
+ str = data:receive(%d)
+ data:send(str)
+ ]], c))
+ data:send(string.rep("a", c))
+ data:receive(c)
+ t = t + c
+ local r, s, a = data:getstats()
+ assert(r == t, "received count failed" .. tostring(r)
+ .. "/" .. tostring(t))
+ assert(s == t, "sent count failed" .. tostring(s)
+ .. "/" .. tostring(t))
+ end
+ print("ok")
+end
+
+
+------------------------------------------------------------------------
+function test_nonblocking(size)
+ reconnect()
+print("Testing " .. 2*size .. " bytes")
+remote(string.format([[
+ data:send(string.rep("a", %d))
+ socket.sleep(0.5)
+ data:send(string.rep("b", %d) .. "\n")
+]], size, size))
+ local err = "timeout"
+ local part = ""
+ local str
+ data:settimeout(0)
+ while 1 do
+ str, err, part = data:receive("*l", part)
+ if err ~= "timeout" then break end
+ end
+ assert(str == (string.rep("a", size) .. string.rep("b", size)))
+ reconnect()
+remote(string.format([[
+ str = data:receive(%d)
+ socket.sleep(0.5)
+ str = data:receive(%d, str)
+ data:send(str)
+]], size, size))
+ data:settimeout(0)
+ local start = 0
+ while 1 do
+ ret, err, start = data:send(str, start+1)
+ if err ~= "timeout" then break end
+ end
+ data:send("\n")
+ data:settimeout(-1)
+ local back = data:receive(2*size)
+ assert(back == str, "'" .. back .. "' vs '" .. str .. "'")
+ print("ok")
+end
+
+------------------------------------------------------------------------
+
+test("method registration")
+test_methods(socket.unix(), {
+ "accept",
+ "bind",
+ "close",
+ "connect",
+ "dirty",
+ "getfd",
+ "getstats",
+ "setstats",
+ "listen",
+ "receive",
+ "send",
+ "setfd",
+ "setoption",
+ "setpeername",
+ "setsockname",
+ "settimeout",
+ "shutdown",
+})
+
+test("connect function")
+--connect_timeout()
+--empty_connect()
+--connect_errors()
+
+--test("rebinding: ")
+--rebind_test()
+
+test("active close: ")
+active_close()
+
+test("closed connection detection: ")
+test_closed()
+
+test("accept function: ")
+accept_timeout()
+accept_errors()
+
+test("getstats test")
+getstats_test()
+
+test("character line")
+test_asciiline(1)
+test_asciiline(17)
+test_asciiline(200)
+test_asciiline(4091)
+test_asciiline(80199)
+test_asciiline(8000000)
+test_asciiline(80199)
+test_asciiline(4091)
+test_asciiline(200)
+test_asciiline(17)
+test_asciiline(1)
+
+test("mixed patterns")
+test_mixed(1)
+test_mixed(17)
+test_mixed(200)
+test_mixed(4091)
+test_mixed(801990)
+test_mixed(4091)
+test_mixed(200)
+test_mixed(17)
+test_mixed(1)
+
+test("binary line")
+test_rawline(1)
+test_rawline(17)
+test_rawline(200)
+test_rawline(4091)
+test_rawline(80199)
+test_rawline(8000000)
+test_rawline(80199)
+test_rawline(4091)
+test_rawline(200)
+test_rawline(17)
+test_rawline(1)
+
+test("raw transfer")
+test_raw(1)
+test_raw(17)
+test_raw(200)
+test_raw(4091)
+test_raw(80199)
+test_raw(8000000)
+test_raw(80199)
+test_raw(4091)
+test_raw(200)
+test_raw(17)
+test_raw(1)
+
+test("non-blocking transfer")
+test_nonblocking(1)
+test_nonblocking(17)
+test_nonblocking(200)
+test_nonblocking(4091)
+test_nonblocking(80199)
+test_nonblocking(8000000)
+test_nonblocking(80199)
+test_nonblocking(4091)
+test_nonblocking(200)
+test_nonblocking(17)
+test_nonblocking(1)
+
+test("total timeout on send")
+test_totaltimeoutsend(800091, 1, 3)
+test_totaltimeoutsend(800091, 2, 3)
+test_totaltimeoutsend(800091, 5, 2)
+test_totaltimeoutsend(800091, 3, 1)
+
+test("total timeout on receive")
+test_totaltimeoutreceive(800091, 1, 3)
+test_totaltimeoutreceive(800091, 2, 3)
+test_totaltimeoutreceive(800091, 3, 2)
+test_totaltimeoutreceive(800091, 3, 1)
+
+test("blocking timeout on send")
+test_blockingtimeoutsend(800091, 1, 3)
+test_blockingtimeoutsend(800091, 2, 3)
+test_blockingtimeoutsend(800091, 3, 2)
+test_blockingtimeoutsend(800091, 3, 1)
+
+test("blocking timeout on receive")
+test_blockingtimeoutreceive(800091, 1, 3)
+test_blockingtimeoutreceive(800091, 2, 3)
+test_blockingtimeoutreceive(800091, 3, 2)
+test_blockingtimeoutreceive(800091, 3, 1)
+
+test(string.format("done in %.2fs", socket.gettime() - start))
diff --git a/test/utestsrvr.lua b/test/utestsrvr.lua
new file mode 100644
index 0000000..f7be196
--- /dev/null
+++ b/test/utestsrvr.lua
@@ -0,0 +1,17 @@
+require("socket");
+os.remove("/tmp/luasocket")
+socket = require("socket.unix");
+host = "luasocket";
+server = socket.unix()
+print(server:bind(host))
+print(server:listen(5))
+ack = "\n";
+while 1 do
+ print("server: waiting for client connection...");
+ control = assert(server:accept());
+ while 1 do
+ command = assert(control:receive());
+ assert(control:send(ack));
+ (loadstring(command))();
+ end
+end
--
cgit v1.2.3-55-g6feb