diff --git a/parl/remote/communication.py b/parl/remote/communication.py index 467f21a8e..c5bd93adb 100644 --- a/parl/remote/communication.py +++ b/parl/remote/communication.py @@ -15,6 +15,7 @@ import cloudpickle import subprocess import os +import pickle5 from parl.utils import SerializeError, DeserializeError __all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return'] @@ -35,21 +36,10 @@ def _deserialize_serializable(obj): val.__dict__.update(obj["data"]) return val - context = pyarrow.default_serialization_context() + buffer_list = [] - # support deserialize in another environment - context.set_pickle(cloudpickle.dumps, cloudpickle.loads) - - # support serialize and deserialize custom class - context.register_type( - object, - "object", - custom_serializer=_serialize_serializable, - custom_deserializer=_deserialize_serializable) - - # if pyarrow is installed, parl will use pyarrow to serialize/deserialize objects. - serialize = lambda data: pyarrow.serialize(data, context=context).to_buffer() - deserialize = lambda data: pyarrow.deserialize(data, context=context) + serialize = lambda data: pickle5.dumps(data, protocol=5, buffer_callback=buffer_list.append) + deserialize = lambda data: pickle5.loads(data, buffers=buffer_list) else: # if pyarrow is not installed, parl will use cloudpickle to serialize/deserialize objects. serialize = lambda data: cloudpickle.dumps(data) diff --git a/parl/remote/job.py b/parl/remote/job.py index 4761139f2..5d0eb52ae 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -23,7 +23,7 @@ import argparse import cloudpickle -import pickle +import pickle5 import psutil import re import sys @@ -247,7 +247,7 @@ def wait_for_files(self, reply_socket, job_address): message = reply_socket.recv_multipart() tag = message[0] if tag == remote_constants.SEND_FILE_TAG: - pyfiles = pickle.loads(message[1]) + pyfiles = pickle5.loads(message[1]) envdir = tempfile.mkdtemp() for empty_subfolder in pyfiles['empty_subfolders']: diff --git a/parl/remote/monitor.py b/parl/remote/monitor.py index 295bcccfb..d27166e7f 100644 --- a/parl/remote/monitor.py +++ b/parl/remote/monitor.py @@ -13,7 +13,7 @@ # limitations under the License. import argparse -import pickle +import pickle5 import random import time import zmq @@ -60,7 +60,7 @@ def run(self): self.socket.send_multipart([b'[MONITOR]']) msg = self.socket.recv_multipart() - status = pickle.loads(msg[1]) + status = pickle5.loads(msg[1]) data = {'workers': [], 'clients': []} total_vacant_cpus = 0 total_used_cpus = 0 diff --git a/parl/remote/tests/log_server_test.py b/parl/remote/tests/log_server_test.py index bcc31f27c..0e3461cab 100644 --- a/parl/remote/tests/log_server_test.py +++ b/parl/remote/tests/log_server_test.py @@ -15,7 +15,7 @@ import json import multiprocessing import os -import pickle +import pickle5 import subprocess import sys import tempfile @@ -83,7 +83,7 @@ def test_log_server(self): # Get status status = master._get_status() - client_jobs = pickle.loads(status).get('client_jobs') + client_jobs = pickle5.loads(status).get('client_jobs') self.assertIsNotNone(client_jobs) # Get job id diff --git a/setup.py b/setup.py index 477b63a7a..bf1b5dce6 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def find_version(*file_paths): "tensorboard<=2.11.0", "flask>=1.0.4", "click", + "pickle5", "psutil>=5.6.2", "flask_cors", "requests",