Repository URL to install this package:
|
Version:
2022.10.0 ▾
|
from __future__ import annotations
import ctypes
import struct
from collections.abc import Sequence
import dask
from distributed.utils import nbytes
BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))
msgpack_opts = {
("max_%s_len" % x): 2**31 - 1 for x in ["str", "bin", "array", "map", "ext"]
}
msgpack_opts["strict_map_key"] = False
msgpack_opts["raw"] = False
def frame_split_size(
frame: bytes | memoryview, n: int = BIG_BYTES_SHARD_SIZE
) -> list[memoryview]:
"""
Split a frame into a list of frames of maximum size
This helps us to avoid passing around very large bytestrings.
Examples
--------
>>> frame_split_size([b'12345', b'678'], n=3) # doctest: +SKIP
[b'123', b'45', b'678']
"""
n = n or BIG_BYTES_SHARD_SIZE
frame = memoryview(frame)
if frame.nbytes <= n:
return [frame]
nitems = frame.nbytes // frame.itemsize
items_per_shard = n // frame.itemsize
return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)]
def pack_frames_prelude(frames):
nframes = len(frames)
nbytes_frames = map(nbytes, frames)
return struct.pack(f"Q{nframes}Q", nframes, *nbytes_frames)
def pack_frames(frames):
"""Pack frames into a byte-like object
This prepends length information to the front of the bytes-like object
See Also
--------
unpack_frames
"""
return b"".join([pack_frames_prelude(frames), *frames])
def unpack_frames(b):
"""Unpack bytes into a sequence of frames
This assumes that length information is at the front of the bytestring,
as performed by pack_frames
See Also
--------
pack_frames
"""
b = memoryview(b)
fmt = "Q"
fmt_size = struct.calcsize(fmt)
(n_frames,) = struct.unpack_from(fmt, b)
lengths = struct.unpack_from(f"{n_frames}{fmt}", b, fmt_size)
frames = []
start = fmt_size * (1 + n_frames)
for length in lengths:
end = start + length
frames.append(b[start:end])
start = end
return frames
def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
"""
Zero-copy "concatenate" a sequence of contiguous memoryviews.
Returns a new memoryview which slices into the underlying buffer
to extract out the portion equivalent to all of ``mvs`` being concatenated.
All the memoryviews must:
* Share the same underlying buffer (``.obj``)
* When merged, cover a continuous portion of that buffer with no gaps
* Have the same strides
* Be 1-dimensional
* Have the same format
* Be contiguous
Raises ValueError if these conditions are not met.
"""
if not mvs:
return memoryview(bytearray())
if len(mvs) == 1:
return mvs[0]
first = mvs[0]
if not isinstance(first, memoryview):
raise TypeError(f"Expected memoryview; got {type(first)}")
obj = first.obj
format = first.format
first_start_addr = 0
nbytes = 0
for i, mv in enumerate(mvs):
if not isinstance(mv, memoryview):
raise TypeError(f"{i}: expected memoryview; got {type(mv)}")
if mv.nbytes == 0:
continue
if mv.obj is not obj:
raise ValueError(
f"{i}: memoryview has different buffer: {mv.obj!r} vs {obj!r}"
)
if not mv.contiguous:
raise ValueError(f"{i}: memoryview non-contiguous")
if mv.ndim != 1:
raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
if mv.format != format:
raise ValueError(f"{i}: inconsistent format: {mv.format} vs {format}")
start_addr = address_of_memoryview(mv)
if first_start_addr == 0:
first_start_addr = start_addr
else:
expected_addr = first_start_addr + nbytes
if start_addr != expected_addr:
raise ValueError(
f"memoryview {i} does not start where the previous ends. "
f"Expected {expected_addr:x}, starts {start_addr - expected_addr} byte(s) away."
)
nbytes += mv.nbytes
if nbytes == 0:
# all memoryviews were zero-length
assert len(first) == 0
return first
assert first_start_addr != 0, "Underlying buffer is null pointer?!"
base_mv = memoryview(obj).cast("B")
base_start_addr = address_of_memoryview(base_mv)
start_index = first_start_addr - base_start_addr
return base_mv[start_index : start_index + nbytes].cast(format)
one_byte_carr = ctypes.c_byte * 1
# ^ length and type don't matter, just use it to get the address of the first byte
def address_of_memoryview(mv: memoryview) -> int:
"""
Get the pointer to the first byte of a memoryview's data.
If the memoryview is read-only, NumPy must be installed.
"""
# NOTE: this method relies on pointer arithmetic to figure out
# where each memoryview starts within the underlying buffer.
# There's no direct API to get the address of a memoryview,
# so we use a trick through ctypes and the buffer protocol:
# https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
try:
carr = one_byte_carr.from_buffer(mv)
except TypeError:
# `mv` is read-only. `from_buffer` requires the buffer to be writeable.
# See https://bugs.python.org/issue11427 for discussion.
# This typically comes from `deserialize_bytes`, where `mv.obj` is an
# immutable bytestring.
pass
else:
return ctypes.addressof(carr)
try:
import numpy as np
except ImportError:
raise ValueError(
f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
)
# NumPy doesn't mind read-only buffers. We could just use this method
# for all cases, but it's nice to use the pure-Python method for the common
# case of writeable buffers (created by TCP comms, for example).
return np.asarray(mv).__array_interface__["data"][0]