From 56e63b7a9edf56693944d3b05ae76e1d2be27991 Mon Sep 17 00:00:00 2001
From: tedu <>
Date: Thu, 30 Oct 2014 16:06:07 +0000
Subject: rework the poll loop to poll in both directions so it doesn't get
 stuck if one pipe stalls out. from a diff by Arne Becker. (buffer size left
 alone for now)

---
 src/usr.bin/nc/netcat.c | 258 +++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 211 insertions(+), 47 deletions(-)

diff --git a/src/usr.bin/nc/netcat.c b/src/usr.bin/nc/netcat.c
index e6ec97ed0f..a8e90186e9 100644
--- a/src/usr.bin/nc/netcat.c
+++ b/src/usr.bin/nc/netcat.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: netcat.c,v 1.124 2014/10/26 13:59:30 millert Exp $ */
+/* $OpenBSD: netcat.c,v 1.125 2014/10/30 16:06:07 tedu Exp $ */
 /*
  * Copyright (c) 2001 Eric Jackson <ericj@monkey.org>
  *
@@ -64,6 +64,12 @@
 #define PORT_MAX_LEN	6
 #define UNIX_DG_TMP_SOCKET_SIZE	19
 
+#define POLL_STDIN 0
+#define POLL_NETOUT 1
+#define POLL_NETIN 2
+#define POLL_STDOUT 3
+#define BUFSIZE 2048
+
 /* Command Line Options */
 int	dflag;					/* detached, no stdin */
 int	Fflag;					/* fdpass sock to stdout */
@@ -111,6 +117,8 @@ void	set_common_sockopts(int);
 int	map_tos(char *, int *);
 void	report_connect(const struct sockaddr *, socklen_t);
 void	usage(int);
+ssize_t drainbuf(int, unsigned char *, size_t *);
+ssize_t fillbuf(int, unsigned char *, size_t *);
 
 int
 main(int argc, char *argv[])
@@ -390,7 +398,7 @@ main(int argc, char *argv[])
 				    &len);
 				if (connfd == -1) {
 					/* For now, all errnos are fatal */
-   					err(1, "accept");
+					err(1, "accept");
 				}
 				if (vflag)
 					report_connect((struct sockaddr *)&cliaddr, len);
@@ -729,68 +737,224 @@ local_listen(char *host, char *port, struct addrinfo hints)
  * Loop that polls on the network file descriptor and stdin.
  */
 void
-readwrite(int nfd)
+readwrite(int net_fd)
 {
-	struct pollfd pfd[2];
-	unsigned char buf[16 * 1024];
-	int n, wfd = fileno(stdin);
-	int lfd = fileno(stdout);
-	int plen;
-
-	plen = sizeof(buf);
-
-	/* Setup Network FD */
-	pfd[0].fd = nfd;
-	pfd[0].events = POLLIN;
-
-	/* Set up STDIN FD. */
-	pfd[1].fd = wfd;
-	pfd[1].events = POLLIN;
+	struct pollfd pfd[4];
+	int stdin_fd = STDIN_FILENO;
+	int stdout_fd = STDOUT_FILENO;
+	unsigned char netinbuf[BUFSIZE];
+	size_t netinbufpos = 0;
+	unsigned char stdinbuf[BUFSIZE];
+	size_t stdinbufpos = 0;
+	int n, num_fds, flags;
+	ssize_t ret;
+
+	/* don't read from stdin if requested */
+	if (dflag)
+		stdin_fd = -1;
+
+	/* stdin */
+	pfd[POLL_STDIN].fd = stdin_fd;
+	pfd[POLL_STDIN].events = POLLIN;
+
+	/* network out */
+	pfd[POLL_NETOUT].fd = net_fd;
+	pfd[POLL_NETOUT].events = 0;
+
+	/* network in */
+	pfd[POLL_NETIN].fd = net_fd;
+	pfd[POLL_NETIN].events = POLLIN;
+
+	/* stdout */
+	pfd[POLL_STDOUT].fd = stdout_fd;
+	pfd[POLL_STDOUT].events = 0;
+
+	while (1) {
+		/* both inputs are gone, buffers are empty, we are done */
+		if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1
+		    && stdinbufpos == 0 && netinbufpos == 0) {
+			close(net_fd);
+			return;
+		}
+		/* both outputs are gone, we can't continue */
+		if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1) {
+			close(net_fd);
+			return;
+		}
+		/* listen and net in gone, queues empty, done */
+		if (lflag && pfd[POLL_NETIN].fd == -1
+		    && stdinbufpos == 0 && netinbufpos == 0) {
+			close(net_fd);
+			return;
+		}
 
-	while (pfd[0].fd != -1) {
+		/* help says -i is for "wait between lines sent". We read and
+		 * write arbitrary amounts of data, and we don't want to start
+		 * scanning for newlines, so this is as good as it gets */
 		if (iflag)
 			sleep(iflag);
 
-		if ((n = poll(pfd, 2 - dflag, timeout)) < 0) {
-			int saved_errno = errno;
-			close(nfd);
-			errc(1, saved_errno, "Polling Error");
+		/* poll */
+		num_fds = poll(pfd, 4, timeout);
+
+		/* treat poll errors */
+		if (num_fds == -1) {
+			close(net_fd);
+			err(1, "polling error");
 		}
 
-		if (n == 0)
+		/* timeout happened */
+		if (num_fds == 0)
 			return;
 
-		if (pfd[0].revents & (POLLIN|POLLHUP)) {
-			if ((n = read(nfd, buf, plen)) < 0)
-				return;
-			else if (n == 0) {
-				shutdown(nfd, SHUT_RD);
-				pfd[0].fd = -1;
-				pfd[0].events = 0;
-			} else {
-				if (tflag)
-					atelnet(nfd, buf, n);
-				if (atomicio(vwrite, lfd, buf, n) != n)
-					return;
+		/* treat socket error conditions */
+		for (n = 0; n < 4; n++) {
+			if (pfd[n].revents & (POLLERR|POLLNVAL)) {
+				pfd[n].fd = -1;
 			}
 		}
+		/* reading is possible after HUP */
+		if (pfd[POLL_STDIN].events & POLLIN &&
+		    pfd[POLL_STDIN].revents & POLLHUP &&
+		    ! (pfd[POLL_STDIN].revents & POLLIN))
+				pfd[POLL_STDIN].fd = -1;
+
+		if (pfd[POLL_NETIN].events & POLLIN &&
+		    pfd[POLL_NETIN].revents & POLLHUP &&
+		    ! (pfd[POLL_NETIN].revents & POLLIN))
+				pfd[POLL_NETIN].fd = -1;
+
+		if (pfd[POLL_NETOUT].revents & POLLHUP) {
+			if (Nflag)
+				shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
+			pfd[POLL_NETOUT].fd = -1;
+		}
+		/* if HUP, stop watching stdout */
+		if (pfd[POLL_STDOUT].revents & POLLHUP)
+			pfd[POLL_STDOUT].fd = -1;
+		/* if no net out, stop watching stdin */
+		if (pfd[POLL_NETOUT].fd == -1)
+			pfd[POLL_STDIN].fd = -1;
+		/* if no stdout, stop watching net in */
+		if (pfd[POLL_STDOUT].fd == -1) {
+			if (pfd[POLL_NETIN].fd != -1)
+				shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
+			pfd[POLL_NETIN].fd = -1;
+		}
 
-		if (!dflag && pfd[1].revents & (POLLIN|POLLHUP)) {
-			if ((n = read(wfd, buf, plen)) < 0)
-				return;
-			else if (n == 0) {
-				if (Nflag)
-					shutdown(nfd, SHUT_WR);
-				pfd[1].fd = -1;
-				pfd[1].events = 0;
-			} else {
-				if (atomicio(vwrite, nfd, buf, n) != n)
-					return;
+		/* try to read from stdin */
+		if (pfd[POLL_STDIN].revents & POLLIN && stdinbufpos < BUFSIZE) {
+			ret = fillbuf(pfd[POLL_STDIN].fd, stdinbuf,
+			    &stdinbufpos);
+			/* error or eof on stdin - remove from pfd */
+			if (ret == 0 || ret == -1)
+				pfd[POLL_STDIN].fd = -1;
+			/* read something - poll net out */
+			if (stdinbufpos > 0)
+				pfd[POLL_NETOUT].events = POLLOUT;
+			/* filled buffer - remove self from polling */
+			if (stdinbufpos == BUFSIZE)
+				pfd[POLL_STDIN].events = 0;
+		}
+		/* try to write to network */
+		if (pfd[POLL_NETOUT].revents & POLLOUT && stdinbufpos > 0) {
+			ret = drainbuf(pfd[POLL_NETOUT].fd, stdinbuf,
+			    &stdinbufpos);
+			if (ret == -1)
+				pfd[POLL_NETOUT].fd = -1;
+			/* buffer empty - remove self from polling */
+			if (stdinbufpos == 0)
+				pfd[POLL_NETOUT].events = 0;
+			/* buffer no longer full - poll stdin again */
+			if (stdinbufpos < BUFSIZE)
+				pfd[POLL_STDIN].events = POLLIN;
+		}
+		/* try to read from network */
+		if (pfd[POLL_NETIN].revents & POLLIN && netinbufpos < BUFSIZE) {
+			ret = fillbuf(pfd[POLL_NETIN].fd, netinbuf,
+			    &netinbufpos);
+			if (ret == -1)
+				pfd[POLL_NETIN].fd = -1;
+			/* eof on net in - remove from pfd */
+			if (ret == 0) {
+				shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
+				pfd[POLL_NETIN].fd = -1;
 			}
+			/* read something - poll stdout */
+			if (netinbufpos > 0)
+				pfd[POLL_STDOUT].events = POLLOUT;
+			/* filled buffer - remove self from polling */
+			if (netinbufpos == BUFSIZE)
+				pfd[POLL_NETIN].events = 0;
+			/* handle telnet */
+			if (tflag)
+				atelnet(pfd[POLL_NETIN].fd, netinbuf,
+				    netinbufpos);
+		}
+		/* try to write to stdout */
+		if (pfd[POLL_STDOUT].revents & POLLOUT && netinbufpos > 0) {
+			ret = drainbuf(pfd[POLL_STDOUT].fd, netinbuf,
+			    &netinbufpos);
+			if (ret == -1)
+				pfd[POLL_STDOUT].fd = -1;
+			/* buffer empty - remove self from polling */
+			if (netinbufpos == 0)
+				pfd[POLL_STDOUT].events = 0;
+			/* buffer no longer full - poll net in again */
+			if (netinbufpos < BUFSIZE)
+				pfd[POLL_NETIN].events = POLLIN;
+		}
+
+		/* stdin gone and queue empty? */
+		if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) {
+			if (pfd[POLL_NETOUT].fd != -1 && Nflag)
+				shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
+			pfd[POLL_NETOUT].fd = -1;
+		}
+		/* net in gone and queue empty? */
+		if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0) {
+			pfd[POLL_STDOUT].fd = -1;
 		}
 	}
 }
 
+ssize_t
+drainbuf(int fd, unsigned char *buf, size_t *bufpos)
+{
+	ssize_t n;
+	ssize_t adjust;
+
+	n = write(fd, buf, *bufpos);
+	/* don't treat EAGAIN, EINTR as error */
+	if (n == -1 && (errno == EAGAIN || errno == EINTR))
+		n = -2;
+	if (n <= 0)
+		return n;
+	/* adjust buffer */
+	adjust = *bufpos - n;
+	if (adjust > 0)
+		memmove(buf, buf + n, adjust);
+	*bufpos -= n;
+	return n;
+}
+
+
+ssize_t
+fillbuf(int fd, unsigned char *buf, size_t *bufpos)
+{
+	size_t num = BUFSIZE - *bufpos;
+	ssize_t n;
+
+	n = read(fd, buf + *bufpos, num);
+	/* don't treat EAGAIN, EINTR as error */
+	if (n == -1 && (errno == EAGAIN || errno == EINTR))
+		n = -2;
+	if (n <= 0)
+		return n;
+	*bufpos += n;
+	return n;
+}
+
 /*
  * fdpass()
  * Pass the connected file descriptor to stdout and exit.
-- 
cgit v1.2.3-55-g6feb