diff --git a/src/trio/_tests/test_fakenet.py b/src/trio/_tests/test_fakenet.py index d250a105a3..f7f5c16b4d 100644 --- a/src/trio/_tests/test_fakenet.py +++ b/src/trio/_tests/test_fakenet.py @@ -1,10 +1,21 @@ import errno +import re +import socket +import sys import pytest import trio from trio.testing._fake_net import FakeNet +# ENOTCONN gives different messages on different platforms +if sys.platform == "linux": + ENOTCONN_MSG = "Transport endpoint is not connected" +elif sys.platform == "darwin": + ENOTCONN_MSG = "Socket is not connected" +else: + ENOTCONN_MSG = "Unknown error" + def fn() -> FakeNet: fn = FakeNet() @@ -26,6 +37,11 @@ async def test_basic_udp() -> None: await s1.bind(("192.0.2.1", 0)) assert exc.value.errno == errno.EINVAL + # Cannot bind multiple sockets to the same address + with pytest.raises(OSError) as exc: + await s2.bind(("127.0.0.1", port)) + assert exc.value.errno == errno.EADDRINUSE + await s2.sendto(b"xyz", s1.getsockname()) data, addr = await s1.recvfrom(10) assert data == b"xyz" @@ -45,7 +61,231 @@ async def test_msg_trunc() -> None: data, addr = await s1.recvfrom(10) +async def test_recv_methods() -> None: + """Test all recv methods for codecov""" + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + # receiving on an unbound socket is a bad idea (I think?) + with pytest.raises(NotImplementedError, match="code will most likely hang"): + await s2.recv(10) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + + # recvfrom + await s2.sendto(b"abc", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"abc" + assert addr == s2.getsockname() + + # recv + await s1.sendto(b"def", s2.getsockname()) + data = await s2.recv(10) + assert data == b"def" + + # recvfrom_into + assert await s1.sendto(b"ghi", s2.getsockname()) == 3 + buf = bytearray(10) + + with pytest.raises(NotImplementedError, match="partial recvfrom_into"): + (nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2) + + (nbytes, addr) = await s2.recvfrom_into(buf) + assert nbytes == 3 + assert buf == b"ghi" + b"\x00" * 7 + assert addr == s1.getsockname() + + # recv_into + assert await s1.sendto(b"jkl", s2.getsockname()) == 3 + buf2 = bytearray(10) + nbytes = await s2.recv_into(buf2) + assert nbytes == 3 + assert buf2 == b"jkl" + b"\x00" * 7 + + if sys.platform == "linux" and sys.implementation.name == "cpython": + flags: int = socket.MSG_MORE + else: + flags = 1 + + # Send seems explicitly non-functional + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + await s2.send(b"mno") + assert exc.value.errno == errno.ENOTCONN + with pytest.raises(NotImplementedError, match="FakeNet send flags must be 0, not"): + await s2.send(b"mno", flags) + + # sendto errors + # it's successfully used earlier + with pytest.raises(NotImplementedError, match="FakeNet send flags must be 0, not"): + await s2.sendto(b"mno", flags, s1.getsockname()) + with pytest.raises(TypeError, match="wrong number of arguments"): + await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload] + + +@pytest.mark.skipif( + sys.platform == "win32", reason="functions not in socket on windows" +) +async def test_nonwindows_functionality() -> None: + # mypy doesn't support a good way of aborting typechecking on different platforms + if sys.platform != "win32": # pragma: no branch + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s2.bind(("127.0.0.1", 0)) + + # sendmsg + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + await s2.sendmsg([b"mno"]) + assert exc.value.errno == errno.ENOTCONN + + assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3 + (data, ancdata, msg_flags, addr) = await s2.recvmsg(10) + assert data == b"jkl" + assert ancdata == [] + assert msg_flags == 0 + assert addr == s1.getsockname() + + # TODO: recvmsg + + # recvmsg_into + assert await s1.sendto(b"xyzw", s2.getsockname()) == 4 + buf1 = bytearray(2) + buf2 = bytearray(3) + ret = await s2.recvmsg_into([buf1, buf2]) + (nbytes, ancdata, msg_flags, addr) = ret + assert nbytes == 4 + assert buf1 == b"xy" + assert buf2 == b"zw" + b"\x00" + assert ancdata == [] + assert msg_flags == 0 + assert addr == s1.getsockname() + + # recvmsg_into with MSG_TRUNC set + assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5 + buf1 = bytearray(2) + ret = await s2.recvmsg_into([buf1]) + (nbytes, ancdata, msg_flags, addr) = ret + assert nbytes == 2 + assert buf1 == b"xy" + assert ancdata == [] + assert msg_flags == socket.MSG_TRUNC + assert addr == s1.getsockname() + + with pytest.raises( + AttributeError, match="'FakeSocket' object has no attribute 'share'" + ): + await s1.share(0) # type: ignore[attr-defined] + + +@pytest.mark.skipif( + sys.platform != "win32", reason="windows-specific fakesocket testing" +) +async def test_windows_functionality() -> None: + # mypy doesn't support a good way of aborting typechecking on different platforms + if sys.platform == "win32": # pragma: no branch + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + with pytest.raises( + AttributeError, match="'FakeSocket' object has no attribute 'sendmsg'" + ): + await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined] + with pytest.raises( + AttributeError, match="'FakeSocket' object has no attribute 'recvmsg'" + ): + s2.recvmsg(0) # type: ignore[attr-defined] + with pytest.raises( + AttributeError, + match="'FakeSocket' object has no attribute 'recvmsg_into'", + ): + s2.recvmsg_into([]) # type: ignore[attr-defined] + with pytest.raises(NotImplementedError): + s1.share(0) + + async def test_basic_tcp() -> None: fn() with pytest.raises(NotImplementedError): trio.socket.socket() + + +async def test_not_implemented_functions() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + # getsockopt + with pytest.raises(OSError, match="FakeNet doesn't implement getsockopt"): + s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + # setsockopt + with pytest.raises( + NotImplementedError, match="FakeNet always has IPV6_V6ONLY=True" + ): + s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + with pytest.raises(OSError, match="FakeNet doesn't implement setsockopt"): + s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) + with pytest.raises(OSError, match="FakeNet doesn't implement setsockopt"): + s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + # set_inheritable + s1.set_inheritable(False) + with pytest.raises( + NotImplementedError, match="FakeNet can't make inheritable sockets" + ): + s1.set_inheritable(True) + + # get_inheritable + assert not s1.get_inheritable() + + +async def test_getpeername() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + s1.getpeername() + assert exc.value.errno == errno.ENOTCONN + + await s1.bind(("127.0.0.1", 0)) + + with pytest.raises( + AssertionError, + match="This method seems to assume that self._binding has a remote UDPEndpoint", + ): + s1.getpeername() + + +async def test_init() -> None: + fn() + with pytest.raises( + NotImplementedError, + match=re.escape( + f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}" + ), + ): + s1 = trio.socket.socket() + + # getsockname on unbound ipv4 socket + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + assert s1.getsockname() == ("0.0.0.0", 0) + + # getsockname on bound ipv4 socket + await s1.bind(("0.0.0.0", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + + # getsockname on unbound ipv6 socket + s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM) + assert s2.getsockname() == ("::", 0) + + # getsockname on bound ipv6 socket + await s2.bind(("::", 0)) + ip, port, *_ = s2.getsockname() + assert ip == "::1" + assert port != 0 + assert _ == [0, 0] diff --git a/src/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py index 622d5e713e..cadf437a2e 100644 --- a/src/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -12,6 +12,8 @@ import errno import ipaddress import os +import socket +import sys from typing import ( TYPE_CHECKING, Any, @@ -53,12 +55,13 @@ def _wildcard_ip_for(family: int) -> IPAddress: raise NotImplementedError("Unhandled ip address family") # pragma: no cover -def _localhost_ip_for(family: int) -> IPAddress: +# not used anywhere +def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover if family == trio.socket.AF_INET: return ipaddress.ip_address("127.0.0.1") elif family == trio.socket.AF_INET6: return ipaddress.ip_address("::1") - raise NotImplementedError("Unhandled ip address family") # pragma: no cover + raise NotImplementedError("Unhandled ip address family") def _fake_err(code: int) -> NoReturn: @@ -67,12 +70,12 @@ def _fake_err(code: int) -> NoReturn: def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int: written = 0 - for buf in buffers: + for buf in buffers: # pragma: no branch next_piece = data[written : written + memoryview(buf).nbytes] with memoryview(buf) as mbuf: mbuf[: len(next_piece)] = next_piece written += len(next_piece) - if written == len(data): + if written == len(data): # pragma: no branch break return written @@ -114,7 +117,8 @@ class UDPPacket: destination: UDPEndpoint payload: bytes = attr.ib(repr=lambda p: p.hex()) - def reply(self, payload: bytes) -> UDPPacket: + # not used/tested anywhere + def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover return UDPPacket( source=self.destination, destination=self.source, payload=payload ) @@ -161,8 +165,8 @@ async def getnameinfo( class FakeNet: def __init__(self) -> None: # When we need to pick an arbitrary unique ip address/port, use these: - self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() - self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() # type: ignore[assignment] + self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested + self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested self._auto_port_iter = iter(range(50000, 65535)) self._bound: dict[UDPBinding, FakeSocket] = {} @@ -200,9 +204,9 @@ def __init__( ): self._fake_net = fake_net - if not family: + if not family: # pragma: no cover family = trio.socket.AF_INET - if not type: + if not type: # pragma: no cover type = trio.socket.SOCK_STREAM if family not in (trio.socket.AF_INET, trio.socket.AF_INET6): @@ -240,7 +244,6 @@ def _check_closed(self) -> None: _fake_err(errno.EBADF) def close(self) -> None: - # breakpoint() if self._closed: return self._closed = True @@ -274,7 +277,9 @@ async def bind(self, addr: object) -> None: if self._binding is not None: _fake_err(errno.EINVAL) await trio.lowlevel.checkpoint() - ip_str, port = await self._resolve_address_nocp(addr, local=True) + ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True) + assert _ == [], "TODO: handle other values?" + ip = ipaddress.ip_address(ip_str) assert _family_for(ip) == self.family # We convert binds to INET_ANY into binds to localhost @@ -291,25 +296,14 @@ async def bind(self, addr: object) -> None: async def connect(self, peer: object) -> NoReturn: raise NotImplementedError("FakeNet does not (yet) support connected sockets") - async def sendmsg(self, *args: Any) -> int: + async def _sendmsg( + self, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: Any | None = None, + ) -> int: self._check_closed() - ancdata = [] - flags = 0 - address = None - - # This does *not* match up with socket.socket.sendmsg (!!!) - # https://docs.python.org/3/library/socket.html#socket.socket.sendmsg - # they always have (buffers, ancdata, flags, address) - if len(args) == 1: - (buffers,) = args - elif len(args) == 2: - buffers, address = args - elif len(args) == 3: - buffers, flags, address = args - elif len(args) == 4: - buffers, ancdata, flags, address = args - else: - raise TypeError("wrong number of arguments") await trio.lowlevel.checkpoint() @@ -341,7 +335,12 @@ async def sendmsg(self, *args: Any) -> int: return len(payload) - async def recvmsg_into( + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + sendmsg = _sendmsg + + async def _recvmsg_into( self, buffers: Iterable[Buffer], ancbufsize: int = 0, @@ -351,6 +350,14 @@ async def recvmsg_into( raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: raise NotImplementedError("FakeNet doesn't support any recv flags") + if self._binding is None: + # I messed this up a few times when writing tests ... but it also never happens + # in any of the existing tests, so maybe it could be intentional... + raise NotImplementedError( + "The code will most likely hang if you try to receive on a fakesocket " + "without a binding. If that is not the case, or you explicitly want to " + "test that, remove this warning." + ) self._check_closed() @@ -364,6 +371,11 @@ async def recvmsg_into( msg_flags |= trio.socket.MSG_TRUNC return written, ancdata, msg_flags, address + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg_into = _recvmsg_into + ################################################################ # Simple state query stuff ################################################################ @@ -385,7 +397,7 @@ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: assert hasattr( self._binding, "remote" ), "This method seems to assume that self._binding has a remote UDPEndpoint" - if self._binding.remote is not None: + if self._binding.remote is not None: # pragma: no cover assert isinstance( self._binding.remote, UDPEndpoint ), "Self._binding.remote should be a UDPEndpoint" @@ -450,7 +462,25 @@ def __exit__( async def send(self, data: Buffer, flags: int = 0) -> int: return await self.sendto(data, flags, None) + @overload + async def sendto( + self, __data: Buffer, __address: tuple[object, ...] | str | Buffer + ) -> int: + ... + + @overload + async def sendto( + self, + __data: Buffer, + __flags: int, + __address: tuple[object, ...] | str | None | Buffer, + ) -> int: + ... + async def sendto(self, *args: Any) -> int: + data: Buffer + flags: int + address: tuple[object, ...] | str | Buffer if len(args) == 2: data, address = args flags = 0 @@ -458,7 +488,7 @@ async def sendto(self, *args: Any) -> int: data, flags, address = args else: raise TypeError("wrong number of arguments") - return await self.sendmsg([data], [], flags, address) + return await self._sendmsg([data], [], flags, address) async def recv(self, bufsize: int, flags: int = 0) -> bytes: data, address = await self.recvfrom(bufsize, flags) @@ -469,7 +499,7 @@ async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: return got_bytes async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: - data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) + data, ancdata, msg_flags, address = await self._recvmsg(bufsize, flags) return data, address async def recvfrom_into( @@ -477,20 +507,25 @@ async def recvfrom_into( ) -> tuple[int, Any]: if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], 0, flags ) return got_nbytes, address - async def recvmsg( + async def _recvmsg( self, bufsize: int, ancbufsize: int = 0, flags: int = 0 ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: buf = bytearray(bufsize) - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], ancbufsize, flags ) return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg = _recvmsg + def fileno(self) -> int: raise NotImplementedError("can't get fileno() for FakeNet sockets") @@ -504,5 +539,9 @@ def set_inheritable(self, inheritable: bool) -> None: if inheritable: raise NotImplementedError("FakeNet can't make inheritable sockets") - def share(self, process_id: int) -> bytes: - raise NotImplementedError("FakeNet can't share sockets") + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "share") + ): + + def share(self, process_id: int) -> bytes: + raise NotImplementedError("FakeNet can't share sockets")