osm-labo/venv/lib/python3.12/site-packages/zmq/backend/cffi/socket.py
2025-08-31 12:29:41 +02:00

435 lines
14 KiB
Python

"""zmq Socket class"""
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import errno as errno_mod
import warnings
import zmq
from zmq.constants import SocketOption, _OptType
from zmq.error import ZMQError, _check_rc, _check_version
from ._cffi import ffi
from ._cffi import lib as C
from .message import Frame
from .utils import _retry_sys_call
def new_sizet_pointer(length):
    """Allocate a cffi ``size_t*`` whose pointee is initialised to *length*."""
    return ffi.new('size_t*', length)


# Short alias used by the pointer helpers in this module.
nsp = new_sizet_pointer
def new_uint64_pointer():
    """Return a fresh ``uint64_t*`` paired with a ``size_t*`` holding its size.

    The second element points at ``sizeof(uint64_t)``.
    """
    value_ptr = ffi.new('uint64_t*')
    size_ptr = nsp(ffi.sizeof('uint64_t'))
    return value_ptr, size_ptr
def new_int64_pointer():
    """Return a fresh ``int64_t*`` paired with a ``size_t*`` holding its size.

    The second element points at ``sizeof(int64_t)``.
    """
    value_ptr = ffi.new('int64_t*')
    size_ptr = nsp(ffi.sizeof('int64_t'))
    return value_ptr, size_ptr
def new_int_pointer():
    """Return a fresh ``int*`` paired with a ``size_t*`` holding its size.

    The second element points at ``sizeof(int)``.
    """
    value_ptr = ffi.new('int*')
    size_ptr = nsp(ffi.sizeof('int'))
    return value_ptr, size_ptr
def new_binary_data(length):
    """Return a ``char[length]`` buffer and a ``size_t*`` holding its byte size.

    Parameters
    ----------
    length : int
        Number of ``char`` elements to allocate.
    """
    buffer = ffi.new(f'char[{length:d}]')
    size_ptr = nsp(ffi.sizeof('char') * length)
    return buffer, size_ptr
def value_uint64_pointer(val):
    """Return a ``uint64_t*`` initialised to *val* and ``sizeof(uint64_t)``.

    Unlike the ``new_*`` helpers, the size is returned as a plain integer
    rather than as a ``size_t*``.
    """
    value_ptr = ffi.new('uint64_t*', val)
    size = ffi.sizeof('uint64_t')
    return value_ptr, size
def value_int64_pointer(val):
    """Return an ``int64_t*`` initialised to *val* and ``sizeof(int64_t)``.

    The size is returned as a plain integer, not as a ``size_t*``.
    """
    value_ptr = ffi.new('int64_t*', val)
    size = ffi.sizeof('int64_t')
    return value_ptr, size
def value_int_pointer(val):
    """Return an ``int*`` initialised to *val* and ``sizeof(int)``.

    The size is returned as a plain integer, not as a ``size_t*``.
    """
    value_ptr = ffi.new('int*', val)
    size = ffi.sizeof('int')
    return value_ptr, size
def value_binary_data(val, length):
    """Return a ``char`` array initialised from *val* and its payload size.

    The array is allocated with ``length + 1`` elements (presumably to leave
    room for a trailing NUL -- confirm against cffi's initialiser rules),
    while the reported size is ``length`` chars.

    Parameters
    ----------
    val
        Initialiser passed to ``ffi.new`` for the array.
    length : int
        Payload length in chars.
    """
    buffer = ffi.new(f'char[{length + 1:d}]', val)
    size = ffi.sizeof('char') * length
    return buffer, size
# Size in bytes of the C type ZMQ_FD_T as reported by cffi.
_fd_size = ffi.sizeof('ZMQ_FD_T')
# True when ZMQ_FD_T is 8 bytes wide; fd-typed socket options are then
# read and written through int64 pointers instead of int pointers.
ZMQ_FD_64BIT = _fd_size == 8
# Value returned by the C helper get_ipc_path_max_len(); bind() reports it as
# the limit (sizeof(sockaddr_un.sun_path)) when an ipc path is too long.
# NOTE(review): bind() treats a falsy value as "no limit known" -- confirm.
IPC_PATH_MAX_LEN = C.get_ipc_path_max_len()
def new_pointer_from_opt(option, length=0):
    """Allocate an output pointer pair suited to *option*'s value type.

    Parameters
    ----------
    option
        A socket option; its ``_opt_type`` attribute selects the C type.
        Objects without that attribute are treated as ``int`` options.
    length : int
        Buffer length to allocate; only used for bytes-typed options.

    Returns
    -------
    tuple
        ``(value_pointer, size_pointer)`` as produced by the ``new_*`` helpers.
    """
    kind = getattr(option, "_opt_type", _OptType.int)
    if kind == _OptType.bytes:
        return new_binary_data(length)
    # fd options are 64-bit wide only when ZMQ_FD_T itself is.
    is_wide = kind == _OptType.int64 or (ZMQ_FD_64BIT and kind == _OptType.fd)
    return new_int64_pointer() if is_wide else new_int_pointer()
def value_from_opt_pointer(option, opt_pointer, length=0):
    """Convert a filled option pointer into a Python value.

    Parameters
    ----------
    option
        Socket option (or raw value convertible to ``SocketOption``).
    opt_pointer
        cffi pointer holding the option value.
    length : int
        Number of bytes to read; only used for bytes-typed options.

    Returns
    -------
    bytes or int
        A copy of the first *length* bytes for bytes-typed options, otherwise
        the pointee converted with ``int()``.
    """
    try:
        known_option = SocketOption(option)
    except ValueError:
        # Unrecognized option: assume it is from a newer libzmq and treat it
        # as an int; libzmq itself raises EINVAL if it is invalid.
        kind = _OptType.int
    else:
        kind = known_option._opt_type

    if kind != _OptType.bytes:
        return int(opt_pointer[0])
    return ffi.buffer(opt_pointer, length)[:]
def initialize_opt_pointer(option, value, length=0):
    """Build an input pointer for *option* initialised with *value*.

    Parameters
    ----------
    option
        A socket option; its ``_opt_type`` attribute selects the C type.
        Objects without that attribute are treated as ``int`` options.
    value
        The Python value to store.
    length : int
        Payload length; only used for bytes-typed options.

    Returns
    -------
    tuple
        ``(value_pointer, size)`` as produced by the ``value_*`` helpers.
    """
    kind = getattr(option, "_opt_type", _OptType.int)
    if kind == _OptType.bytes:
        return value_binary_data(value, length)
    # fd options are 64-bit wide only when ZMQ_FD_T itself is.
    is_wide = kind == _OptType.int64 or (ZMQ_FD_64BIT and kind == _OptType.fd)
    if is_wide:
        return value_int64_pointer(value)
    return value_int_pointer(value)
class Socket:
    """CFFI-backed wrapper around a libzmq socket handle.

    The instance owns (or, when created with ``shadow``, wraps) a ``void*``
    socket stored in ``_zmq_socket`` and forwards bind/connect, option access,
    send/recv and monitoring calls to libzmq through ``C``.
    """

    # Class-level defaults for instance state.
    context = None
    socket_type = None
    _zmq_socket = None
    _closed = None
    _ref = None
    _shadow = False
    # Draft poller used as a stand-in for zmq.FD (see ``get``).
    _draft_poller = None
    _draft_poller_ptr = None
    copy_threshold = 0

    def __init__(self, context=None, socket_type=None, shadow=0, copy_threshold=None):
        """Create a new libzmq socket, or wrap an existing one.

        Parameters
        ----------
        context
            Object exposing ``_zmq_ctx``; required unless *shadow* is given.
        socket_type
            Socket type passed to ``zmq_socket``.
        shadow
            If truthy, a value cast to ``void*`` and used as the socket handle
            instead of creating a new socket.
        copy_threshold : int, optional
            Messages smaller than this are always copied on send.
            Defaults to ``zmq.COPY_THRESHOLD``.

        Raises
        ------
        ZMQError
            If the resulting socket handle is NULL.
        """
        if copy_threshold is None:
            copy_threshold = zmq.COPY_THRESHOLD
        self.copy_threshold = copy_threshold

        self.context = context
        self._draft_poller = self._draft_poller_ptr = None
        if shadow:
            self._zmq_socket = ffi.cast("void *", shadow)
            self._shadow = True
        else:
            self._shadow = False
            self._zmq_socket = C.zmq_socket(context._zmq_ctx, socket_type)
        if self._zmq_socket == ffi.NULL:
            raise ZMQError()
        self._closed = False

    @property
    def underlying(self):
        """The address of the underlying libzmq socket"""
        return int(ffi.cast('size_t', self._zmq_socket))

    def _check_closed_deep(self):
        """Thorough check of whether the socket has been closed,
        even if by another entity (e.g. ctx.destroy).

        Only used by the `closed` property.

        Returns True if closed, False otherwise.
        """
        if self._closed:
            return True
        try:
            # Any getsockopt will do; it fails with ENOTSOCK on a dead handle.
            self.get(zmq.TYPE)
        except ZMQError as e:
            if e.errno == zmq.ENOTSOCK:
                self._closed = True
                return True
            elif e.errno == zmq.ETERM:
                # Context terminated: the socket is not reported as closed.
                pass
            else:
                raise
        return False

    @property
    def closed(self):
        """Whether the socket is closed (performs a deep check)."""
        return self._check_closed_deep()

    def close(self, linger=None):
        """Close the socket, destroying the draft poller first if one exists.

        Parameters
        ----------
        linger : int, optional
            If given, set as ``zmq.LINGER`` before closing.

        Raises
        ------
        ZMQError
            If the last libzmq call made here returned a negative code.
        """
        rc = 0
        if not self._closed and hasattr(self, '_zmq_socket'):
            if self._draft_poller_ptr is not None:
                rc = C.zmq_poller_destroy(self._draft_poller_ptr)
                self._draft_poller = self._draft_poller_ptr = None

            if self._zmq_socket is not None:
                if linger is not None:
                    self.set(zmq.LINGER, linger)
                rc = C.zmq_close(self._zmq_socket)
            self._closed = True
        if rc < 0:
            _check_rc(rc)

    def bind(self, address):
        """Bind the socket to *address*.

        Parameters
        ----------
        address : str or bytes
            Endpoint to bind; ``str`` is encoded as UTF-8 for libzmq.

        Raises
        ------
        ZMQError
            On failure. ENAMETOOLONG (when ``IPC_PATH_MAX_LEN`` is truthy) and
            ENOENT get a message that names the ipc path.
        """
        if isinstance(address, str):
            address_b = address.encode('utf8')
        else:
            address_b = address
        if isinstance(address, bytes):
            # keep a text form around for error messages
            address = address_b.decode('utf8')
        rc = C.zmq_bind(self._zmq_socket, address_b)
        if rc < 0:
            if IPC_PATH_MAX_LEN and C.zmq_errno() == errno_mod.ENAMETOOLONG:
                path = address.split('://', 1)[-1]
                msg = (
                    f'ipc path "{path}" is longer than {IPC_PATH_MAX_LEN} '
                    'characters (sizeof(sockaddr_un.sun_path)).'
                )
                raise ZMQError(C.zmq_errno(), msg=msg)
            elif C.zmq_errno() == errno_mod.ENOENT:
                path = address.split('://', 1)[-1]
                msg = f'No such file or directory for ipc path "{path}".'
                raise ZMQError(C.zmq_errno(), msg=msg)
            else:
                _check_rc(rc)

    def unbind(self, address):
        """Unbind the socket from *address* (``str`` is encoded as UTF-8)."""
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_unbind(self._zmq_socket, address)
        _check_rc(rc)

    def connect(self, address):
        """Connect the socket to *address* (``str`` is encoded as UTF-8)."""
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_connect(self._zmq_socket, address)
        _check_rc(rc)

    def disconnect(self, address):
        """Disconnect the socket from *address* (``str`` is encoded as UTF-8)."""
        if isinstance(address, str):
            address = address.encode('utf8')
        rc = C.zmq_disconnect(self._zmq_socket, address)
        _check_rc(rc)

    def set(self, option, value):
        """Set a socket option via ``zmq_setsockopt``.

        Parameters
        ----------
        option
            Socket option; unrecognized values are passed through as ints.
        value : int or bytes
            Option value; ``bytes`` only for bytes-typed options.

        Raises
        ------
        TypeError
            If *value* is ``str``, if bytes are given for a non-bytes option,
            or if a non-bytes value is given for a bytes option.
        """
        length = None
        if isinstance(value, str):
            raise TypeError("unicode not allowed, use bytes")

        try:
            option = SocketOption(option)
        except ValueError:
            # unrecognized option,
            # assume from the future,
            # let EINVAL raise
            opt_type = _OptType.int
        else:
            opt_type = option._opt_type

        if isinstance(value, bytes):
            if opt_type != _OptType.bytes:
                raise TypeError(f"not a bytes sockopt: {option}")
            length = len(value)
        elif opt_type == _OptType.bytes:
            # Without this check the failure surfaces later as an opaque
            # TypeError from ``None + 1`` while sizing the C buffer.
            raise TypeError(f"expected bytes for sockopt {option}, got: {value!r}")

        c_value_pointer, c_sizet = initialize_opt_pointer(option, value, length)

        _retry_sys_call(
            C.zmq_setsockopt,
            self._zmq_socket,
            option,
            ffi.cast('void*', c_value_pointer),
            c_sizet,
        )

    def get(self, option):
        """Get a socket option via ``zmq_getsockopt``.

        For ``zmq.FD`` on a socket where getsockopt fails with EINVAL and
        ``THREAD_SAFE`` is set, a draft ``zmq_poller`` is created and its fd
        is returned instead; later ``zmq.FD`` lookups reuse that poller.

        Returns
        -------
        int or bytes
            ``bytes`` for bytes-typed options (one trailing NUL stripped,
            except for ROUTING_ID), otherwise ``int``.

        Raises
        ------
        ZMQError
            If the underlying libzmq call fails.
        RuntimeError
            If the draft FD fallback is needed but libzmq lacks draft support.
        """
        try:
            option = SocketOption(option)
        except ValueError:
            # unrecognized option,
            # assume from the future,
            # let EINVAL raise
            opt_type = _OptType.int
        else:
            opt_type = option._opt_type

        if option == zmq.FD and self._draft_poller is not None:
            # a draft poller was created earlier: its fd stands in for ours
            c_value_pointer, _ = new_pointer_from_opt(option)
            C.zmq_poller_fd(self._draft_poller, ffi.cast('void*', c_value_pointer))
            return int(c_value_pointer[0])

        c_value_pointer, c_sizet_pointer = new_pointer_from_opt(option, length=255)

        try:
            _retry_sys_call(
                C.zmq_getsockopt,
                self._zmq_socket,
                option,
                c_value_pointer,
                c_sizet_pointer,
            )
        except ZMQError as e:
            if (
                option == SocketOption.FD
                and e.errno == zmq.Errno.EINVAL
                and self.get(SocketOption.THREAD_SAFE)
            ):
                _check_version((4, 3, 2), "draft socket FD support via zmq_poller_fd")
                if not zmq.has('draft'):
                    raise RuntimeError("libzmq must be built with draft support")
                warnings.warn(zmq.error.DraftFDWarning(), stacklevel=2)

                # create a poller and retrieve its fd
                self._draft_poller_ptr = ffi.new("void*[1]")
                self._draft_poller_ptr[0] = self._draft_poller = C.zmq_poller_new()
                if self._draft_poller == ffi.NULL:
                    # failed (why?), raise original error
                    self._draft_poller_ptr = self._draft_poller = None
                    raise
                # register self with poller
                rc = C.zmq_poller_add(
                    self._draft_poller,
                    self._zmq_socket,
                    ffi.NULL,
                    zmq.POLLIN | zmq.POLLOUT,
                )
                _check_rc(rc)
                # use poller fd as proxy for ours
                rc = C.zmq_poller_fd(
                    self._draft_poller, ffi.cast('void *', c_value_pointer)
                )
                _check_rc(rc)
                return int(c_value_pointer[0])
            else:
                raise

        sz = c_sizet_pointer[0]
        v = value_from_opt_pointer(option, c_value_pointer, sz)
        if (
            option != zmq.SocketOption.ROUTING_ID
            and opt_type == _OptType.bytes
            and v.endswith(b'\0')
        ):
            # strip a single trailing NUL from bytes options
            v = v[:-1]
        return v

    def _send_copy(self, buf, flags):
        """Send a copy of a bufferable.

        The data is copied into a freshly initialised ``zmq_msg_t`` which is
        always closed before returning, whether or not the send succeeded.
        """
        zmq_msg = ffi.new('zmq_msg_t*')
        if not isinstance(buf, bytes):
            # cast any bufferable data to bytes via memoryview
            buf = memoryview(buf).tobytes()

        c_message = ffi.new('char[]', buf)
        rc = C.zmq_msg_init_size(zmq_msg, len(buf))
        _check_rc(rc)
        C.memcpy(C.zmq_msg_data(zmq_msg), c_message, len(buf))
        try:
            _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
        except Exception:
            # The message was not sent, so release it here (same cleanup as
            # recv does for its message) instead of leaking it.
            C.zmq_msg_close(zmq_msg)
            raise
        rc2 = C.zmq_msg_close(zmq_msg)
        _check_rc(rc2)

    def _send_frame(self, frame, flags):
        """Send a Frame on this socket in a non-copy manner.

        Returns the tracker of the frame copy that was sent.
        """
        # Always copy the Frame so the original message isn't garbage collected.
        # This doesn't do a real copy, just a reference.
        frame_copy = frame.fast_copy()
        zmq_msg = frame_copy.zmq_msg
        _retry_sys_call(C.zmq_msg_send, zmq_msg, self._zmq_socket, flags)
        tracker = frame_copy.tracker
        frame_copy.close()
        return tracker

    def send(self, data, flags=0, copy=False, track=False):
        """Send a single message part.

        Parameters
        ----------
        data : bytes-like or Frame
            Payload to send; ``str`` is rejected.
        flags : int
            Flags forwarded to ``zmq_msg_send``.
        copy : bool
            If true (and *data* is not a Frame), send a copy of the bytes.
        track : bool
            Passed to the Frame created for non-copying sends.

        Returns
        -------
        ``None`` when sent with ``copy=True``; ``zmq._FINISHED_TRACKER`` when
        the payload was below ``copy_threshold`` and therefore copied;
        otherwise the tracker returned by ``_send_frame``.

        Raises
        ------
        TypeError
            If *data* is ``str``.
        ValueError
            If ``track`` is requested for a Frame that has no tracker.
        """
        if isinstance(data, str):
            raise TypeError("Message must be in bytes, not a unicode object")

        if copy and not isinstance(data, Frame):
            return self._send_copy(data, flags)
        else:
            close_frame = False
            if isinstance(data, Frame):
                if track and not data.tracker:
                    raise ValueError('Not a tracked message')
                frame = data
            else:
                if self.copy_threshold:
                    buf = memoryview(data)
                    # always copy messages smaller than copy_threshold
                    if buf.nbytes < self.copy_threshold:
                        self._send_copy(buf, flags)
                        return zmq._FINISHED_TRACKER
                frame = Frame(data, track=track, copy_threshold=self.copy_threshold)
                # we created this Frame, so we are responsible for closing it
                close_frame = True

            tracker = self._send_frame(frame, flags)
            if close_frame:
                frame.close()
            return tracker

    def recv(self, flags=0, copy=True, track=False):
        """Receive a single message part.

        Parameters
        ----------
        flags : int
            Flags forwarded to ``zmq_msg_recv``.
        copy : bool
            If true, return the payload as ``bytes``; otherwise return a
            ``zmq.Frame`` that holds the received message.
        track : bool
            Passed to the Frame created when ``copy`` is false.
        """
        if copy:
            zmq_msg = ffi.new('zmq_msg_t*')
            C.zmq_msg_init(zmq_msg)
        else:
            frame = zmq.Frame(track=track)
            zmq_msg = frame.zmq_msg

        try:
            _retry_sys_call(C.zmq_msg_recv, zmq_msg, self._zmq_socket, flags)
        except Exception:
            if copy:
                # only the locally initialised message is closed here
                C.zmq_msg_close(zmq_msg)
            raise

        if not copy:
            return frame

        _buffer = ffi.buffer(C.zmq_msg_data(zmq_msg), C.zmq_msg_size(zmq_msg))
        # slicing copies the payload out before the message is closed
        _bytes = _buffer[:]
        rc = C.zmq_msg_close(zmq_msg)
        _check_rc(rc)
        return _bytes

    def recv_into(self, buffer, /, *, nbytes: int = 0, flags: int = 0) -> int:
        """Receive a message directly into a writable, contiguous buffer.

        Parameters
        ----------
        buffer
            Object supporting the buffer protocol.
        nbytes : int
            Maximum number of bytes to receive; ``0`` means the full size of
            *buffer*.
        flags : int
            Flags forwarded to ``zmq_recv``.

        Returns
        -------
        int
            The value returned by ``zmq_recv``.

        Raises
        ------
        BufferError
            If *buffer* is not contiguous or is read-only.
        ValueError
            If *nbytes* is negative or larger than *buffer*.
        """
        view = memoryview(buffer)
        if not view.contiguous:
            raise BufferError("Can only recv_into contiguous buffers")
        if view.readonly:
            raise BufferError("Cannot recv_into readonly buffer")
        if nbytes < 0:
            raise ValueError(f"{nbytes=} must be non-negative")
        view_bytes = view.nbytes
        if nbytes == 0:
            nbytes = view_bytes
        elif nbytes > view_bytes:
            raise ValueError(f"{nbytes=} too big for memoryview of {view_bytes}B")
        c_buf = ffi.from_buffer(view)
        rc: int = _retry_sys_call(C.zmq_recv, self._zmq_socket, c_buf, nbytes, flags)
        _check_rc(rc)
        return rc

    def monitor(self, addr, events=-1):
        """s.monitor(addr, events)

        Start publishing socket events on inproc.
        See libzmq docs for zmq_monitor for details.

        Note: requires libzmq >= 3.2

        Parameters
        ----------
        addr : str
            The inproc url used for monitoring. Passing None as
            the addr will cause an existing socket monitor to be
            deregistered.
        events : int [default: zmq.EVENT_ALL]
            The zmq event bitmask for which events will be sent to the monitor.

        Raises
        ------
        ZMQError
            If ``zmq_socket_monitor`` reports failure.
        """
        if events < 0:
            events = zmq.EVENT_ALL
        if addr is None:
            addr = ffi.NULL
        if isinstance(addr, str):
            addr = addr.encode('utf8')
        # Check the return code like every other libzmq call in this class,
        # so a failed monitor registration is not silently ignored.
        rc = C.zmq_socket_monitor(self._zmq_socket, addr, events)
        _check_rc(rc)
# Public names exported by this module.
__all__ = ['Socket', 'IPC_PATH_MAX_LEN']