# Repository URL to install this package:
# Version: 2022.10.0
from __future__ import annotations
import logging
import msgpack
import dask.config
from distributed.protocol import pickle
from distributed.protocol.compression import decompress, maybe_compress
from distributed.protocol.serialize import (
Pickled,
Serialize,
Serialized,
ToPickle,
merge_and_deserialize,
msgpack_decode_default,
msgpack_encode_default,
serialize_and_split,
)
from distributed.protocol.utils import msgpack_opts
from distributed.utils import ensure_memoryview
logger = logging.getLogger(__name__)
def dumps( # type: ignore[no-untyped-def]
    msg, serializers=None, on_error="message", context=None, frame_split_size=None
) -> list:
    """Transform Python message to bytestream suitable for communication

    Parameters
    ----------
    msg
        Object to encode. May contain ``Serialize``/``Serialized`` and
        ``ToPickle``/``Pickled`` wrappers anywhere msgpack visits.
    serializers, on_error, context
        Forwarded to ``serialize_and_split`` for ``Serialize`` objects.
        Additionally, if ``context`` contains a ``"compression"`` key it
        overrides the compression used by ``maybe_compress``.
    frame_split_size
        Forwarded as ``size=`` to ``serialize_and_split``.

    Returns
    -------
    list
        ``frames[0]`` is the msgpack-encoded message; each out-of-band
        object contributes its own msgpack sub-header frame followed by
        its sub-frames, appended in encounter order.

    Developer Notes
    ---------------
    The approach here is to use `msgpack.dumps()` to serialize `msg` and
    write the result to the first output frame. If `msgpack.dumps()`
    encounters an object it cannot serialize like a NumPy array, it is handled
    out-of-band by `_encode_default()` and appended to the output frame list.
    """
    try:
        # An explicit compression choice in the context overrides
        # maybe_compress()'s default.
        if context and "compression" in context:
            compress_opts = {"compression": context["compression"]}
        else:
            compress_opts = {}

        def _inplace_compress_frames(header, frames):
            # Compress in place every frame whose compression is still
            # undecided (None) and record the per-frame scheme in the
            # header so loads() can decompress.
            compression = list(header.get("compression", [None] * len(frames)))
            for i in range(len(frames)):
                if compression[i] is None:
                    compression[i], frames[i] = maybe_compress(
                        frames[i], **compress_opts
                    )
            header["compression"] = tuple(compression)

        def create_serialized_sub_frames(obj: Serialized | Serialize) -> list:
            # Already-serialized objects pass through untouched; Serialize
            # wrappers are serialized (and possibly split) now.
            if isinstance(obj, Serialized):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_header, sub_frames = serialize_and_split(
                    obj,
                    serializers=serializers,
                    on_error=on_error,
                    context=context,
                    size=frame_split_size,
                )
                _inplace_compress_frames(sub_header, sub_frames)
            sub_header["num-sub-frames"] = len(sub_frames)
            sub_header = msgpack.dumps(
                sub_header, default=msgpack_encode_default, use_bin_type=True
            )
            # Sub-header frame first, then its data frames.
            return [sub_header] + sub_frames

        def create_pickled_sub_frames(obj: Pickled | ToPickle) -> list:
            # Mirror of create_serialized_sub_frames for the pickle path.
            if isinstance(obj, Pickled):
                sub_header, sub_frames = obj.header, obj.frames
            else:
                sub_frames = []
                sub_header = {
                    "pickled-obj": pickle.dumps(
                        obj.data,
                        # In order to support len() and slicing, we convert
                        # `PickleBuffer` objects to memoryviews of bytes.
                        buffer_callback=lambda x: sub_frames.append(
                            ensure_memoryview(x)
                        ),
                    )
                }
                _inplace_compress_frames(sub_header, sub_frames)
            sub_header["num-sub-frames"] = len(sub_frames)
            # Note: plain msgpack.dumps here — loads() reads this header
            # with a plain msgpack.loads as well.
            sub_header = msgpack.dumps(sub_header)
            return [sub_header] + sub_frames

        # frames[0] is reserved for the top-level msgpack payload; the
        # closures below append out-of-band frames after it.
        frames = [None]

        def _encode_default(obj):
            # msgpack fallback: divert wrapper objects out-of-band and
            # replace them in the msgpack stream with an offset marker.
            if isinstance(obj, (Serialize, Serialized)):
                offset = len(frames)
                frames.extend(create_serialized_sub_frames(obj))
                return {"__Serialized__": offset}
            elif isinstance(obj, (ToPickle, Pickled)):
                offset = len(frames)
                frames.extend(create_pickled_sub_frames(obj))
                return {"__Pickled__": offset}
            else:
                return msgpack_encode_default(obj)

        frames[0] = msgpack.dumps(msg, default=_encode_default, use_bin_type=True)
        return frames
    except Exception:
        logger.critical("Failed to Serialize", exc_info=True)
        raise
def loads(frames, deserialize=True, deserializers=None):
    """Transform bytestream back into Python value

    Inverse of ``dumps``: ``frames[0]`` holds the msgpack payload; objects
    that were diverted out-of-band are referenced by ``{"__Serialized__":
    offset}`` / ``{"__Pickled__": offset}`` markers pointing into `frames`.

    Parameters
    ----------
    frames : list
        Frame list as produced by ``dumps``.
    deserialize : bool
        If False, embedded serialized objects are returned as
        ``Serialized`` placeholders (frames left compressed) instead of
        being deserialized.
    deserializers
        Forwarded to ``merge_and_deserialize``.

    Raises
    ------
    ValueError
        If a pickled payload is encountered while
        ``distributed.scheduler.pickle`` is disabled.
    """
    # Unpickling arbitrary bytes can execute code; only allowed when the
    # config explicitly opts in.
    allow_pickle = dask.config.get("distributed.scheduler.pickle")
    try:

        def _decode_default(obj):
            # frames[0] is the top-level payload, so a strictly positive
            # offset always denotes an out-of-band reference.
            offset = obj.get("__Serialized__", 0)
            if offset > 0:
                sub_header = msgpack.loads(
                    frames[offset],
                    object_hook=msgpack_decode_default,
                    use_list=False,
                    **msgpack_opts,
                )
                offset += 1
                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
                if deserialize:
                    if "compression" in sub_header:
                        sub_frames = decompress(sub_header, sub_frames)
                    return merge_and_deserialize(
                        sub_header, sub_frames, deserializers=deserializers
                    )
                else:
                    return Serialized(sub_header, sub_frames)

            offset = obj.get("__Pickled__", 0)
            if offset > 0:
                # Plain msgpack.loads, matching the plain msgpack.dumps
                # used for this header on the encode side.
                sub_header = msgpack.loads(frames[offset])
                offset += 1
                sub_frames = frames[offset : offset + sub_header["num-sub-frames"]]
                if allow_pickle:
                    # Bug fix: dumps() compresses the pickle's out-of-band
                    # buffers (create_pickled_sub_frames calls
                    # _inplace_compress_frames), so they must be
                    # decompressed before being handed to pickle.loads as
                    # buffers. decompress() is a no-op for frames whose
                    # recorded compression is None.
                    if "compression" in sub_header:
                        sub_frames = decompress(sub_header, sub_frames)
                    return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames)
                else:
                    raise ValueError(
                        "Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`"
                    )

            return msgpack_decode_default(obj)

        return msgpack.loads(
            frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts
        )
    except Exception:
        logger.critical("Failed to deserialize", exc_info=True)
        raise