Repository URL to install this package:
|
Version:
3.0.0.dev0 ▾
|
import json
import logging
import types
from ray import cloudpickle as cloudpickle
from ray._common.utils import binary_to_hex, hex_to_binary
from ray.util.annotations import DeveloperAPI
from ray.util.debug import log_once
# Module-level logger; TuneFunctionEncoder emits its debug-level
# "falling back to cloudpickle" message through it.
logger = logging.getLogger(__name__)
@DeveloperAPI
class TuneFunctionEncoder(json.JSONEncoder):
    """JSON encoder that cloudpickles whatever ``json`` cannot encode itself.

    Plain Python functions are always cloudpickled. Any other object that
    reaches ``default`` is first handed to ``json.JSONEncoder.default``; if
    that raises, the object is cloudpickled as well. A cloudpickled value is
    emitted as a JSON object tagged ``"_type": "CLOUDPICKLE_FALLBACK"``.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for ``obj``.

        Args:
            obj: The object the base encoder could not serialize natively.

        Returns:
            Whatever ``json.JSONEncoder.default`` returns if it succeeds,
            otherwise the tagged dict produced by ``_to_cloudpickle``.
        """
        is_plain_function = isinstance(obj, types.FunctionType)
        if is_plain_function:
            return self._to_cloudpickle(obj)

        try:
            return super().default(obj)
        except Exception:
            # The message is logged at most once per distinct str(obj) key.
            dedup_key = f"tune_func_encode:{str(obj)}"
            if log_once(dedup_key):
                logger.debug("Unable to encode. Falling back to cloudpickle.")
            # Kept inside the handler so a pickling failure is still chained
            # to the original encoding error.
            return self._to_cloudpickle(obj)

    def _to_cloudpickle(self, obj):
        """Wrap ``obj`` in the tagged fallback dict.

        Returns:
            A dict with ``"_type"`` set to ``"CLOUDPICKLE_FALLBACK"`` and
            ``"value"`` set to ``binary_to_hex`` applied to the cloudpickle
            dump of ``obj``.
        """
        pickled = cloudpickle.dumps(obj)
        encoded = binary_to_hex(pickled)
        return {"_type": "CLOUDPICKLE_FALLBACK", "value": encoded}
@DeveloperAPI
class TuneFunctionDecoder(json.JSONDecoder):
    """JSON decoder that restores values stored under the cloudpickle fallback.

    Every decoded JSON object is passed through ``object_hook``. Objects
    tagged ``"_type": "CLOUDPICKLE_FALLBACK"`` are replaced by the result of
    unpickling their ``"value"`` field; all other objects are returned
    unchanged.

    WARNING: unpickling can execute arbitrary code contained in the payload.
    Only decode JSON that comes from a trusted source.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the decoder with this class's ``object_hook`` installed.

        Args:
            *args: Forwarded to ``json.JSONDecoder.__init__``.
            **kwargs: Forwarded to ``json.JSONDecoder.__init__``.

        Raises:
            TypeError: If ``object_hook`` is also given in ``kwargs``, since
                this class always supplies its own hook.
        """
        # Star-args precede the keyword so the call reads in the order Python
        # actually binds the arguments.
        super().__init__(*args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, obj):
        """Replace a tagged fallback dict with the object it encodes.

        Args:
            obj: A dict produced by the JSON parser for one JSON object.

        Returns:
            The unpickled value when ``obj`` carries the
            ``CLOUDPICKLE_FALLBACK`` tag, otherwise ``obj`` itself.
        """
        if obj.get("_type") == "CLOUDPICKLE_FALLBACK":
            return self._from_cloudpickle(obj)
        return obj

    def _from_cloudpickle(self, obj):
        """Unpickle the payload stored under ``obj["value"]``.

        Raises:
            KeyError: If ``obj`` has no ``"value"`` entry.
        """
        # NOTE(security): cloudpickle.loads can run arbitrary code embedded
        # in the payload; callers must not feed untrusted JSON to this class.
        payload = hex_to_binary(obj["value"])
        return cloudpickle.loads(payload)