Sponge
CS144's user-space TCP library
socket.cc
Go to the documentation of this file.
1 #include "socket.hh"
2 
3 #include "util.hh"
4 
5 #include <cstddef>
6 #include <stdexcept>
7 #include <unistd.h>
8 
9 using namespace std;
10 
11 // default constructor for socket of (subclassed) domain and type
14 Socket::Socket(const int domain, const int type) : FileDescriptor(SystemCall("socket", socket(domain, type, 0))) {}
15 
16 // construct from file descriptor
20 Socket::Socket(FileDescriptor &&fd, const int domain, const int type) : FileDescriptor(move(fd)) {
21  int actual_value;
22  socklen_t len;
23 
24  // verify domain
25  len = sizeof(actual_value);
26  SystemCall("getsockopt", getsockopt(fd_num(), SOL_SOCKET, SO_DOMAIN, &actual_value, &len));
27  if ((len != sizeof(actual_value)) or (actual_value != domain)) {
28  throw runtime_error("socket domain mismatch");
29  }
30 
31  // verify type
32  len = sizeof(actual_value);
33  SystemCall("getsockopt", getsockopt(fd_num(), SOL_SOCKET, SO_TYPE, &actual_value, &len));
34  if ((len != sizeof(actual_value)) or (actual_value != type)) {
35  throw runtime_error("socket type mismatch");
36  }
37 }
38 
39 // get the local or peer address the socket is connected to
43 Address Socket::get_address(const string &name_of_function,
44  const function<int(int, sockaddr *, socklen_t *)> &function) const {
45  Address::Raw address;
46  socklen_t size = sizeof(address);
47 
48  SystemCall(name_of_function, function(fd_num(), address, &size));
49 
50  return {address, size};
51 }
52 
54 Address Socket::local_address() const { return get_address("getsockname", getsockname); }
55 
57 Address Socket::peer_address() const { return get_address("getpeername", getpeername); }
58 
59 // bind socket to a specified local address (usually to listen/accept)
61 void Socket::bind(const Address &address) { SystemCall("bind", ::bind(fd_num(), address, address.size())); }
62 
63 // connect socket to a specified peer address
65 void Socket::connect(const Address &address) { SystemCall("connect", ::connect(fd_num(), address, address.size())); }
66 
67 // shut down a socket in the specified way
69 void Socket::shutdown(const int how) {
70  SystemCall("shutdown", ::shutdown(fd_num(), how));
71  switch (how) {
72  case SHUT_RD:
73  register_read();
74  break;
75  case SHUT_WR:
77  break;
78  case SHUT_RDWR:
79  register_read();
81  break;
82  default:
83  throw runtime_error("Socket::shutdown() called with invalid `how`");
84  }
85 }
86 
88 void UDPSocket::recv(received_datagram &datagram, const size_t mtu) {
89  // receive source address and payload
90  Address::Raw datagram_source_address;
91  datagram.payload.resize(mtu);
92 
93  socklen_t fromlen = sizeof(datagram_source_address);
94 
95  const ssize_t recv_len = SystemCall(
96  "recvfrom",
97  ::recvfrom(
98  fd_num(), datagram.payload.data(), datagram.payload.size(), MSG_TRUNC, datagram_source_address, &fromlen));
99 
100  if (recv_len > ssize_t(mtu)) {
101  throw runtime_error("recvfrom (oversized datagram)");
102  }
103 
104  register_read();
105  datagram.source_address = {datagram_source_address, fromlen};
106  datagram.payload.resize(recv_len);
107 }
108 
110  received_datagram ret{{nullptr, 0}, ""};
111  recv(ret, mtu);
112  return ret;
113 }
114 
115 void sendmsg_helper(const int fd_num,
116  const sockaddr *destination_address,
117  const socklen_t destination_address_len,
118  const BufferViewList &payload) {
119  auto iovecs = payload.as_iovecs();
120 
121  msghdr message{};
122  message.msg_name = const_cast<sockaddr *>(destination_address);
123  message.msg_namelen = destination_address_len;
124  message.msg_iov = iovecs.data();
125  message.msg_iovlen = iovecs.size();
126 
127  const ssize_t bytes_sent = SystemCall("sendmsg", ::sendmsg(fd_num, &message, 0));
128 
129  if (size_t(bytes_sent) != payload.size()) {
130  throw runtime_error("datagram payload too big for sendmsg()");
131  }
132 }
133 
134 void UDPSocket::sendto(const Address &destination, const BufferViewList &payload) {
135  sendmsg_helper(fd_num(), destination, destination.size(), payload);
136  register_write();
137 }
138 
139 void UDPSocket::send(const BufferViewList &payload) {
140  sendmsg_helper(fd_num(), nullptr, 0, payload);
141  register_write();
142 }
143 
144 // mark the socket as listening for incoming connections
146 void TCPSocket::listen(const int backlog) { SystemCall("listen", ::listen(fd_num(), backlog)); }
147 
148 // accept a new incoming connection
152  register_read();
153  return TCPSocket(FileDescriptor(SystemCall("accept", ::accept(fd_num(), nullptr, nullptr))));
154 }
155 
156 // set socket option
161 template <typename option_type>
162 void Socket::setsockopt(const int level, const int option, const option_type &option_value) {
163  SystemCall("setsockopt", ::setsockopt(fd_num(), level, option, &option_value, sizeof(option_value)));
164 }
165 
166 // allow local address to be reused sooner, at the cost of some robustness
168 void Socket::set_reuseaddr() { setsockopt(SOL_SOCKET, SO_REUSEADDR, int(true)); }
Wrapper around sockaddr_storage.
Definition: address.hh:17
Wrapper around IPv4 addresses and DNS operations.
Definition: address.hh:13
socklen_t size() const
Size of the underlying address storage.
Definition: address.hh:66
A non-owning temporary view (similar to std::string_view) of a discontiguous string.
Definition: buffer.hh:98
size_t size() const
Size of the string.
Definition: buffer.cc:89
std::vector< iovec > as_iovecs() const
Convert to a vector of iovec structures.
Definition: buffer.cc:97
A reference-counted handle to a file descriptor.
int fd_num() const
underlying descriptor number
void register_read()
increment read count
FileDescriptor(std::shared_ptr< FDWrapper > other_shared_ptr)
Private constructor used by duplicate()
void register_write()
increment write count
void connect(const Address &address)
Connect a socket to a specified peer address with connect(2).
Definition: socket.cc:65
void shutdown(const int how)
Shut down a socket via shutdown(2).
Definition: socket.cc:69
Address get_address(const std::string &name_of_function, const std::function< int(int, sockaddr *, socklen_t *)> &function) const
Get the local or peer address the socket is connected to.
Definition: socket.cc:43
Socket(const int domain, const int type)
Construct via socket(2).
Definition: socket.cc:14
void setsockopt(const int level, const int option, const option_type &option_value)
Wrapper around setsockopt(2).
Definition: socket.cc:162
void bind(const Address &address)
Bind a socket to a specified address with bind(2), usually for listen/accept.
Definition: socket.cc:61
Address local_address() const
Get local address of socket with getsockname(2).
Definition: socket.cc:54
Address peer_address() const
Get peer address of socket with getpeername(2).
Definition: socket.cc:57
void set_reuseaddr()
Allow local address to be reused sooner via SO_REUSEADDR.
Definition: socket.cc:168
A wrapper around TCP sockets.
Definition: socket.hh:88
TCPSocket accept()
Accept a new incoming connection.
Definition: socket.cc:151
TCPSocket()
Default: construct an unbound, unconnected TCP socket.
Definition: socket.hh:96
void listen(const int backlog=16)
Mark a socket as listening for incoming connections.
Definition: socket.cc:146
received_datagram recv(const size_t mtu=65536)
Receive a datagram and the Address of its sender.
void send(const BufferViewList &payload)
Send datagram to the socket's connected address (must call connect() first)
Definition: socket.cc:139
void sendto(const Address &destination, const BufferViewList &payload)
Send a datagram to specified Address.
Definition: socket.cc:134
Returned by UDPSocket::recv; carries received data and information about the sender.
Definition: socket.hh:62
T move(T... args)
void sendmsg_helper(const int fd_num, const sockaddr *destination_address, const socklen_t destination_address_len, const BufferViewList &payload)
Definition: socket.cc:115
SystemCall("socketpair", ::socketpair(AF_UNIX, SOCK_STREAM, 0, fds.data()))