Skip to content

ssl, socket, array: Improve bytes handling #8997

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions stdlib/_socket.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ _CMSG: TypeAlias = tuple[int, int, bytes]
_CMSGArg: TypeAlias = tuple[int, int, ReadableBuffer]

# Addresses can be either tuples of varying lengths (AF_INET, AF_INET6,
# AF_NETLINK, AF_TIPC) or strings (AF_UNIX).
_Address: TypeAlias = tuple[Any, ...] | str
# AF_NETLINK, AF_TIPC) or strings/buffers (AF_UNIX).
# See getsockaddrarg() in socketmodule.c.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more case is bytes only for BTPROTO_HCI on NetBSD, but I don't think this exotic case is worth mentioning.

_Address: TypeAlias = tuple[Any, ...] | str | ReadableBuffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All uses of this alias in the socket stubs allow a ReadableBuffer, so I put ReadableBuffer in the alias and removed explicit usage of bytes below.

The alias is used in some other stubs, but we can review those in a different PR.

_RetAddress: TypeAlias = Any
# TODO Most methods allow bytes as address objects

# ----- Constants -----
# Some socket families are listed in the "Socket families" section of the docs,
Expand Down Expand Up @@ -584,10 +584,10 @@ class socket:
@property
def timeout(self) -> float | None: ...
def __init__(self, family: int = ..., type: int = ..., proto: int = ..., fileno: _FD | None = ...) -> None: ...
def bind(self, __address: _Address | bytes) -> None: ...
def bind(self, __address: _Address) -> None: ...
def close(self) -> None: ...
def connect(self, __address: _Address | bytes) -> None: ...
def connect_ex(self, __address: _Address | bytes) -> int: ...
def connect(self, __address: _Address) -> None: ...
def connect_ex(self, __address: _Address) -> int: ...
def detach(self) -> int: ...
def fileno(self) -> int: ...
def getpeername(self) -> _RetAddress: ...
Expand Down Expand Up @@ -634,7 +634,7 @@ class socket:
def setblocking(self, __flag: bool) -> None: ...
def settimeout(self, __value: float | None) -> None: ...
@overload
def setsockopt(self, __level: int, __optname: int, __value: int | bytes) -> None: ...
def setsockopt(self, __level: int, __optname: int, __value: int | ReadableBuffer) -> None: ...
@overload
def setsockopt(self, __level: int, __optname: int, __value: None, __optlen: int) -> None: ...
if sys.platform == "win32":
Expand Down Expand Up @@ -671,9 +671,9 @@ def ntohs(__x: int) -> int: ... # param & ret val are 16-bit ints
def htonl(__x: int) -> int: ... # param & ret val are 32-bit ints
def htons(__x: int) -> int: ... # param & ret val are 16-bit ints
def inet_aton(__ip_string: str) -> bytes: ... # ret val 4 bytes in length
def inet_ntoa(__packed_ip: bytes) -> str: ...
def inet_ntoa(__packed_ip: ReadableBuffer) -> str: ...
def inet_pton(__address_family: int, __ip_string: str) -> bytes: ...
def inet_ntop(__address_family: int, __packed_ip: bytes) -> str: ...
def inet_ntop(__address_family: int, __packed_ip: ReadableBuffer) -> str: ...
def getdefaulttimeout() -> float | None: ...
def setdefaulttimeout(__timeout: float | None) -> None: ...

Expand Down
12 changes: 8 additions & 4 deletions stdlib/array.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ class array(MutableSequence[_T], Generic[_T]):
@property
def itemsize(self) -> int: ...
@overload
def __init__(self: array[int], __typecode: _IntTypeCode, __initializer: bytes | Iterable[int] = ...) -> None: ...
def __init__(self: array[int], __typecode: _IntTypeCode, __initializer: bytes | bytearray | Iterable[int] = ...) -> None: ...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bytes and bytearray specifically are treated differently. Arbitrary non-iterable buffers (e.g., pickle.PickleBuffer) don't work.

@overload
def __init__(self: array[float], __typecode: _FloatTypeCode, __initializer: bytes | Iterable[float] = ...) -> None: ...
def __init__(
self: array[float], __typecode: _FloatTypeCode, __initializer: bytes | bytearray | Iterable[float] = ...
) -> None: ...
@overload
def __init__(self: array[str], __typecode: _UnicodeTypeCode, __initializer: bytes | Iterable[str] = ...) -> None: ...
def __init__(
self: array[str], __typecode: _UnicodeTypeCode, __initializer: bytes | bytearray | Iterable[str] = ...
) -> None: ...
@overload
def __init__(self, __typecode: str, __initializer: Iterable[_T]) -> None: ...
@overload
def __init__(self, __typecode: str, __initializer: bytes = ...) -> None: ...
def __init__(self, __typecode: str, __initializer: bytes | bytearray = ...) -> None: ...
def append(self, __v: _T) -> None: ...
def buffer_info(self) -> tuple[int, int]: ...
def byteswap(self) -> None: ...
Expand Down
15 changes: 9 additions & 6 deletions stdlib/socket.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ if sys.platform != "win32":
if sys.version_info >= (3, 9):
# flags and address appear to be unused in send_fds and recv_fds
def send_fds(
sock: socket, buffers: Iterable[bytes], fds: bytes | Iterable[int], flags: int = ..., address: None = ...
sock: socket, buffers: Iterable[ReadableBuffer], fds: Iterable[int], flags: int = ..., address: None = ...
) -> int: ...
def recv_fds(sock: socket, bufsize: int, maxfds: int, flags: int = ...) -> tuple[bytes, list[int], int, Any]: ...

Expand Down Expand Up @@ -768,16 +768,14 @@ if sys.version_info >= (3, 11):
def create_connection(
address: tuple[str | None, int],
timeout: float | None = ..., # noqa: F811
source_address: tuple[bytearray | bytes | str, int] | None = ...,
source_address: _Address | None = ...,
*,
all_errors: bool = ...,
) -> socket: ...

else:
def create_connection(
address: tuple[str | None, int],
timeout: float | None = ..., # noqa: F811
source_address: tuple[bytearray | bytes | str, int] | None = ...,
address: tuple[str | None, int], timeout: float | None = ..., source_address: _Address | None = ... # noqa: F811
) -> socket: ...

if sys.version_info >= (3, 8):
Expand All @@ -788,5 +786,10 @@ if sys.version_info >= (3, 8):

# the 5th tuple item is an address
def getaddrinfo(
host: bytes | str | None, port: str | int | None, family: int = ..., type: int = ..., proto: int = ..., flags: int = ...
host: bytes | str | None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It explicitly accepts only bytes, not bytearray.

port: bytes | str | int | None,
family: int = ...,
type: int = ...,
proto: int = ...,
flags: int = ...,
) -> list[tuple[AddressFamily, SocketKind, int, str, tuple[str, int] | tuple[str, int, int, int]]]: ...
46 changes: 27 additions & 19 deletions stdlib/ssl.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
_PeerCertRetDictType: TypeAlias = dict[str, str | _PCTRTTT | _PCTRTT]
_PeerCertRetType: TypeAlias = _PeerCertRetDictType | bytes | None
_EnumRetType: TypeAlias = list[tuple[bytes, str, set[str] | bool]]
_PasswordType: TypeAlias = Union[Callable[[], str | bytes], str, bytes]
_PasswordType: TypeAlias = Union[Callable[[], str | bytes | bytearray], str, bytes, bytearray]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The handling code handles bytes and bytearray, but not arbitrary buffers.


_SrvnmeCbType: TypeAlias = Callable[[SSLSocket | SSLObject, str | None, SSLSocket], int | None]

Expand Down Expand Up @@ -61,7 +61,7 @@ def create_default_context(
*,
cafile: StrOrBytesPath | None = ...,
capath: StrOrBytesPath | None = ...,
cadata: str | bytes | None = ...,
cadata: str | ReadableBuffer | None = ...,
) -> SSLContext: ...
def _create_unverified_context(
protocol: int = ...,
Expand All @@ -73,7 +73,7 @@ def _create_unverified_context(
keyfile: StrOrBytesPath | None = ...,
cafile: StrOrBytesPath | None = ...,
capath: StrOrBytesPath | None = ...,
cadata: str | bytes | None = ...,
cadata: str | ReadableBuffer | None = ...,
) -> SSLContext: ...

_create_default_https_context: Callable[..., SSLContext]
Expand All @@ -82,8 +82,11 @@ def RAND_bytes(__num: int) -> bytes: ...
def RAND_pseudo_bytes(__num: int) -> tuple[bytes, bool]: ...
def RAND_status() -> bool: ...
def RAND_egd(path: str) -> None: ...
def RAND_add(__s: bytes, __entropy: float) -> None: ...
def match_hostname(cert: _PeerCertRetType, hostname: str) -> None: ...
def RAND_add(__string: str | ReadableBuffer, __entropy: float) -> None: ...

if sys.version_info < (3, 12):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed this one is gone in 3.12.

def match_hostname(cert: _PeerCertRetDictType, hostname: str) -> None: ...

def cert_time_to_seconds(cert_time: str) -> int: ...

if sys.version_info >= (3, 10):
Expand All @@ -94,7 +97,7 @@ if sys.version_info >= (3, 10):
else:
def get_server_certificate(addr: tuple[str, int], ssl_version: int = ..., ca_certs: str | None = ...) -> str: ...

def DER_cert_to_PEM_cert(der_cert_bytes: bytes) -> str: ...
def DER_cert_to_PEM_cert(der_cert_bytes: ReadableBuffer) -> str: ...
def PEM_cert_to_DER_cert(pem_cert_string: str) -> bytes: ...

class DefaultVerifyPaths(NamedTuple):
Expand Down Expand Up @@ -290,8 +293,8 @@ class SSLSocket(socket.socket):
@property
def session_reused(self) -> bool | None: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def connect(self, addr: socket._Address | bytes) -> None: ...
def connect_ex(self, addr: socket._Address | bytes) -> int: ...
def connect(self, addr: socket._Address) -> None: ...
def connect_ex(self, addr: socket._Address) -> int: ...
def recv(self, buflen: int = ..., flags: int = ...) -> bytes: ...
def recv_into(self, buffer: WriteableBuffer, nbytes: int | None = ..., flags: int = ...) -> int: ...
def recvfrom(self, buflen: int = ..., flags: int = ...) -> tuple[bytes, socket._RetAddress]: ...
Expand All @@ -301,12 +304,12 @@ class SSLSocket(socket.socket):
def send(self, data: ReadableBuffer, flags: int = ...) -> int: ...
def sendall(self, data: ReadableBuffer, flags: int = ...) -> None: ...
@overload
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address) -> int: ...
def sendto(self, data: ReadableBuffer, flags_or_addr: socket._Address, addr: None = ...) -> int: ...
@overload
def sendto(self, data: ReadableBuffer, flags_or_addr: int | socket._Address, addr: socket._Address | None = ...) -> int: ...
def sendto(self, data: ReadableBuffer, flags_or_addr: int, addr: socket._Address) -> int: ...
def shutdown(self, how: int) -> None: ...
def read(self, len: int = ..., buffer: bytearray | None = ...) -> bytes: ...
def write(self, data: bytes) -> int: ...
def write(self, data: ReadableBuffer) -> int: ...
def do_handshake(self, block: bool = ...) -> None: ... # block is undocumented
@overload
def getpeercert(self, binary_form: Literal[False] = ...) -> _PeerCertRetDictType | None: ...
Expand Down Expand Up @@ -362,7 +365,7 @@ class SSLContext:
) -> None: ...
def load_default_certs(self, purpose: Purpose = ...) -> None: ...
def load_verify_locations(
self, cafile: StrOrBytesPath | None = ..., capath: StrOrBytesPath | None = ..., cadata: str | bytes | None = ...
self, cafile: StrOrBytesPath | None = ..., capath: StrOrBytesPath | None = ..., cadata: str | ReadableBuffer | None = ...
) -> None: ...
@overload
def get_ca_certs(self, binary_form: Literal[False] = ...) -> list[_PeerCertRetDictType]: ...
Expand Down Expand Up @@ -408,7 +411,7 @@ class SSLObject:
def session_reused(self) -> bool: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def read(self, len: int = ..., buffer: bytearray | None = ...) -> bytes: ...
def write(self, data: bytes) -> int: ...
def write(self, data: ReadableBuffer) -> int: ...
@overload
def getpeercert(self, binary_form: Literal[False] = ...) -> _PeerCertRetDictType | None: ...
@overload
Expand All @@ -433,16 +436,21 @@ class MemoryBIO:
pending: int
eof: bool
def read(self, __size: int = ...) -> bytes: ...
def write(self, __buf: bytes) -> int: ...
def write(self, __buf: ReadableBuffer) -> int: ...
def write_eof(self) -> None: ...

@final
class SSLSession:
id: bytes
time: int
timeout: int
ticket_lifetime_hint: int
has_ticket: bool
@property
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these are read-only.

def has_ticket(self) -> bool: ...
@property
def id(self) -> bytes: ...
@property
def ticket_lifetime_hint(self) -> int: ...
@property
def time(self) -> int: ...
@property
def timeout(self) -> int: ...

class SSLErrorNumber(enum.IntEnum):
SSL_ERROR_EOF: int
Expand Down