diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 1a61c542790dd6..130948357a1e8b 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -103,6 +103,9 @@ from . import rpc # noqa: F401 +from .checkpoint.save_state_dict import save_state_dict +from .checkpoint.load_state_dict import load_state_dict + __all__ = [ "io", "spawn", @@ -157,5 +160,7 @@ "Shard", "Replicate", "Partial", + "save_state_dict", + "load_state_dict", "shard_optimizer", ] diff --git a/python/paddle/distributed/checkpoint/__init__.py b/python/paddle/distributed/checkpoint/__init__.py new file mode 100644 index 00000000000000..da89b737adfb8d --- /dev/null +++ b/python/paddle/distributed/checkpoint/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .save_state_dict import save_state_dict +from .load_state_dict import load_state_dict + +__all__ = [ + "save_state_dict", + "load_state_dict", +] diff --git a/python/paddle/distributed/checkpoint/load_state_dict.py b/python/paddle/distributed/checkpoint/load_state_dict.py new file mode 100644 index 00000000000000..153c6764d70d60 --- /dev/null +++ b/python/paddle/distributed/checkpoint/load_state_dict.py @@ -0,0 +1,496 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import os +from dataclasses import dataclass +from typing import Tuple + +import paddle +from paddle.distributed.communication.group import is_initialized +from paddle.distributed.fleet.utils.log_util import logger + +from .metadata import LocalTensorIndex, LocalTensorMetadata +from .utils import compute_local_shape_and_global_offset, flatten_state_dict + + +@dataclass(frozen=True) +class ReadItem: + local_tensor_index: LocalTensorIndex + rank: int + cur_offset: Tuple[int] + storage_offset: Tuple[int] + lengths: Tuple[int] + + +def get_rank_to_files(path, state_dict, process_group, use_dist): + accessible_files = os.listdir(path) + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), f"No metadata file found in the checkpoint directory:{path}." + # The neccesary files to be read + tensor_key_list = [] + necessary_files = [] + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for local_tensor_index, file_name in metadata.storage_metadata.items(): + assert ( + local_tensor_index not in tensor_key_list + ), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata." + tensor_key_list.append(local_tensor_index.tensor_key) + if local_tensor_index.tensor_key in state_dict: + necessary_files.append(file_name) + necessary_data_files_set = set(necessary_files) + # allgather all accessible files + local_data_files = [ + file for file in accessible_files if file.endswith(".distcp") + ] + global_data_files = [] + if use_dist: + paddle.distributed.all_gather_object( + global_data_files, local_data_files, process_group + ) + else: + global_data_files.append(local_data_files) + tmp = [] + for files in global_data_files: + tmp += files + global_data_files_set = set(tmp) + logger.debug( + f"necessary_data_files_set:{necessary_data_files_set}, global_data_files_set:{global_data_files_set}" + ) + # check neccesary files in global_data_files + assert ( + global_data_files_set & necessary_data_files_set + == necessary_data_files_set + ), f"The checkpoint files are not complete. Please check the checkpoint directory:{path}.global_data_files_set:{global_data_files_set}, necessary_data_files_set:{necessary_data_files_set}" + missing_keys = set(state_dict.keys()) - set(tensor_key_list) + if len(missing_keys) > 0: + logger.warning( + f"Missing keys:{missing_keys}, check whether the checkpoint is complete." + ) + + rank_to_files = {} + for rank, local_files in enumerate(global_data_files): + if len(local_files) > 0: + local_files = [ + f for f in local_files if f in necessary_data_files_set + ] + rank_to_files[rank] = local_files + logger.debug(f"mapping rank_to_files:{rank_to_files}") + return rank_to_files + + +def get_local_load_files(rank_to_files): + """ + Load files in a load-balanced manner. + Example: + Case1: all ranks access the same data files + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp], rank1:[0_0.distcp, 1_0.distcp, 2_0.distcp, 3_0.distcp]} + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + Case2: all ranks access different data files but some overlapped + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp, 2_0.distcp], rank1:[2_0.distcp, 3_0.distcp] + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + Case3: all ranks access different data files and no overlapped + rank_to_files = {rank0:[0_0.distcp, 1_0.distcp], rank1:[2_0.distcp, 3_0.distcp] + rank0 return [0_0.distcp, 1_0.distcp], rank1 return [2_0.distcp, 3_0.distcp] + """ + file_to_ranks = {} + for rank, files in rank_to_files.items(): + for file in files: + if file not in file_to_ranks: + file_to_ranks[file] = [] + file_to_ranks[file].append(rank) + rank_to_not_read_files = copy.copy(rank_to_files) + rank_to_read_files = {rank: [] for rank in rank_to_not_read_files.keys()} + for file, ranks in file_to_ranks.items(): + if len(ranks) == 1: + rank = ranks[0] + rank_to_read_files[rank].append(file) + rank_to_not_read_files[rank].remove(file) + if len(rank_to_not_read_files[rank]) == 0: + rank_to_not_read_files.pop(rank) + + logger.debug( + f"rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}" + ) + + def get_least_read_files_ranks(rank_to_read_files): + nums = [ + (rank, len(files)) for rank, files in rank_to_read_files.items() + ] + nums = sorted(nums, key=lambda x: x[1]) + ranks = [rank for rank, num in nums if num == nums[0][1]] + return ranks + + def get_read_rank_file(rank_to_not_read_files, ranks): + if len(rank_to_not_read_files) == 0: + return (None, None) + nums = [ + (rank, len(files)) + for rank, files in rank_to_not_read_files.items() + if rank in ranks + ] + nums = sorted(nums, key=lambda x: x[1]) + rank = nums[0][0] + return (rank, rank_to_not_read_files[rank][0]) + + def update(rank_to_read_files, rank_to_not_read_files, rank_file): + rank, file = rank_file + if rank is None and file is None: + return + if rank not in rank_to_read_files: + rank_to_read_files[rank] = [] + rank_to_read_files[rank].append(file) + # update rank_to_not_read_files + file_to_ranks = {} + for r, files in rank_to_not_read_files.items(): + for f in files: + if f not in file_to_ranks: + file_to_ranks[f] = [] + file_to_ranks[f].append(r) + logger.debug(f"file_to_ranks:{file_to_ranks}") + if file in file_to_ranks: + for r in file_to_ranks[file]: + rank_to_not_read_files[r].remove(file) + if len(rank_to_not_read_files[r]) == 0: + rank_to_not_read_files.pop(r) + + while len(rank_to_not_read_files) > 0: + ranks = get_least_read_files_ranks(rank_to_read_files) + rank_file = get_read_rank_file(rank_to_not_read_files, ranks) + update(rank_to_read_files, rank_to_not_read_files, rank_file) + logger.debug( + f"update rank_to_read_files:{rank_to_read_files}, rank_to_not_read_files:{rank_to_not_read_files}, ranks:{ranks}, rank_file:{rank_file}" + ) + cur_rank = paddle.distributed.get_rank() + if cur_rank in rank_to_read_files: + return rank_to_read_files[cur_rank] + else: + logger.warning(f"rank:{cur_rank} does not need to load checkpoint") + return [] + + +def get_load_infos(path, local_load_files, process_group, use_dist): + load_info = {} + accessible_files = os.listdir(path) + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), "No metadata file found in the checkpoint directory:{path}." + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for local_tensor_index, file_name in metadata.storage_metadata.items(): + if file_name in local_load_files: + load_info[local_tensor_index] = ( + paddle.distributed.get_rank(), + file_name, + ) + load_info_list = [] + if use_dist: + paddle.distributed.all_gather_object( + load_info_list, load_info, process_group + ) + else: + load_info_list.append(load_info) + load_infos = {} + for load_info in load_info_list: + for local_tensor_index, (rank, file_name) in load_info.items(): + assert local_tensor_index not in load_infos + load_infos[local_tensor_index] = (rank, file_name) + return load_infos + + +def compute_overlap( + cur_chunk_metadata: LocalTensorMetadata, + storage_local_tensor_metadata: LocalTensorMetadata, +): + cur_offsets = [] + storage_offsets = [] + lengths = [] + for cur_len, cur_offset, strorage_len, storage_offset in zip( + cur_chunk_metadata.local_shape, + cur_chunk_metadata.global_offset, + storage_local_tensor_metadata.local_shape, + storage_local_tensor_metadata.global_offset, + ): + begin_offset = max(cur_offset, storage_offset) + end_offset = min(cur_offset + cur_len, storage_offset + strorage_len) + if begin_offset == cur_offset: + cur_offsets.append(0) + storage_offsets.append(begin_offset - storage_offset) + elif begin_offset == storage_offset: + cur_offsets.append(begin_offset - cur_offset) + storage_offsets.append(0) + else: + raise ValueError( + f"Invalid begin_offset:{begin_offset}, cur_offset:{cur_offset}, storage_offset:{storage_offset}" + ) + lengths.append(end_offset - begin_offset) + assert ( + lengths[-1] >= 0 + ), f"Invalid length:{lengths[-1]}, end_offset:{end_offset}, begin_offset:{begin_offset}" + return cur_offsets, storage_offsets, lengths + + +def not_overlap( + cur_chunk_metadata: LocalTensorMetadata, + storage_local_tensor_metadata: LocalTensorMetadata, +): + for cur_len, cur_offset, strorage_len, storage_offset in zip( + cur_chunk_metadata.local_shape, + cur_chunk_metadata.global_offset, + storage_local_tensor_metadata.local_shape, + storage_local_tensor_metadata.global_offset, + ): + if ( + cur_offset >= (storage_offset + strorage_len) + or (cur_offset + cur_len) <= storage_offset + ): + return True + return False + + +def get_read_items(path, state_dict, process_group, use_dist): + accessible_files = os.listdir(path) + metadata_files = [ + file for file in accessible_files if file.endswith(".metadata") + ] + assert ( + len(metadata_files) > 0 + ), "No metadata file found in the checkpoint directory:{path}." + storage_state_dict_metadata = {} + for metadata_file in metadata_files: + metadata = paddle.load(os.path.join(path, metadata_file)) + for ( + tensor_key, + local_tensor_metadata, + ) in metadata.state_dict_metadata.items(): + if tensor_key not in storage_state_dict_metadata: + storage_state_dict_metadata[tensor_key] = [] + storage_state_dict_metadata[tensor_key] += local_tensor_metadata + read_items = [] + logger.debug(f"storage_state_dict_metadata:{storage_state_dict_metadata}") + for tensor_key, val in state_dict.items(): + if isinstance(val, paddle.Tensor): + if val.is_dist(): + ( + local_shape, + global_offset, + ) = compute_local_shape_and_global_offset( + val.shape, + val.dist_attr.process_mesh, + val.dist_attr.dims_mapping, + ) + else: + local_shape = val.shape + global_offset = [0] * len(val.shape) + if not local_shape or not global_offset: + continue + cur_chunk_metadata = LocalTensorMetadata(global_offset, local_shape) + assert ( + tensor_key in storage_state_dict_metadata + ), f"tensor_key:{tensor_key} not found in storage_state_dict_metadata:{storage_state_dict_metadata}." + for storage_local_tensor_metadata in storage_state_dict_metadata[ + tensor_key + ]: + if not_overlap( + cur_chunk_metadata, storage_local_tensor_metadata + ): + continue + cur_offsets, storage_offsets, lengths = compute_overlap( + cur_chunk_metadata, storage_local_tensor_metadata + ) + storage_local_tensor_index = LocalTensorIndex( + tensor_key, + tuple(storage_local_tensor_metadata.global_offset), + ) + read_items.append( + ReadItem( + storage_local_tensor_index, + paddle.distributed.get_rank(), + tuple(cur_offsets), + tuple(storage_offsets), + tuple(lengths), + ) + ) + else: + raise ValueError( + f"Only support paddle.Tensor., val type:{type(val)}" + ) + global_read_items = [] + tmp = [] + if use_dist: + paddle.distributed.all_gather_object(tmp, read_items, process_group) + else: + tmp.append(read_items) + for items in tmp: + for item in items: + global_read_items.append(item) + return global_read_items + + +def load_state_dict( + state_dict, + path, + process_group=None, + coordinator_rank=0, +) -> None: + """ + Load the state_dict inplace from a checkpoint path. + Args: + state_dict(Dict[str, paddle.Tensor]): The state_dict to load. It will be modified inplace after loading. + path(str): The directory to load checkpoint files. + process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default. + Example: + .. code-block:: python + >>> # doctest: +SKIP('Load state dict.') + >>> import paddle + >>> import paddle.distributed as dist + >>> ckpt_path = "./checkpoint" + >>> w1 = paddle.arange(32).reshape([4, 8]) + >>> mesh = dist.ProcessMesh([0, 1]) + >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) + >>> state_dict = {"w1": sharded_w1} + >>> dist.save_state_dict(state_dict, ckpt_path) + >>> w1_to_load = paddle.zeros_like(w1) + >>> sharded_w1_to_load = dist.shard_tensor(w1, mesh, [dist.Replicate()]) + >>> state_dict_to_load = {"w1": sharded_w1_to_load} + >>> dist.load_state_dict(state_dict_to_load, ckpt_path) + >>> print(f"state_dict_to_load:{state_dict_to_load}") + state_dict_to_load:{'w1': Tensor(shape=[4, 8], dtype=int64, place=Place(gpu:0), stop_gradient=True, dist_attr={process_mesh: {shape: [2], process_ids: [0,1], dim_names: [d0]}, dims_mappings: [-1,-1], batch_dim: 0, dynamic_dims: [0,0], annotated: [dims_mapping: 1,process_mesh: 1], partial: [].}, GlobalDenseTensor= + [[0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11, 12, 13, 14, 15], + [16, 17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30, 31]])} + >>> # doctest: -SKIP + """ + assert isinstance( + state_dict, dict + ), "The state_dict should be a dictionary." + state_dict = flatten_state_dict(state_dict) + if len(state_dict) > 0: + for val in state_dict.values(): + assert isinstance( + val, paddle.Tensor + ), "Only support dygraph Tensor now, support static DistributedTensor later" + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist and process_group is None and not is_initialized(): + # Init the default global process group + paddle.distributed.init_parallel_env() + + rank_to_files = get_rank_to_files(path, state_dict, process_group, use_dist) + local_load_files = get_local_load_files(rank_to_files) + # load_infos: {LocalTensorIndex: (rank, file_name)}, which local tensor located in which file, and the file is load in which rank. + load_infos = get_load_infos(path, local_load_files, process_group, use_dist) + # read_items: [ReadItem(local_tensor_index, rank, cur_offsets, storage_offsets, lengths)], + # slice the storage local tensor in (storage_offsets, lengths) to assign the current tensor in (cur_offsets, lengths) in rank. + read_items = get_read_items(path, state_dict, process_group, use_dist) + storage_file_to_state_dict = {} + logger.debug( + f"before load, state_dict:{state_dict},\n load_infos:{load_infos},\n read_items:{read_items}" + ) + for item in read_items: + assert ( + item.local_tensor_index in load_infos + ), f"item:{item}, load_infos:{load_infos}" + src_rank, file_name = load_infos[item.local_tensor_index] + storage_chunk_tensor = None + cur_chunk_tensor = None + # The src rank need to load the state_dict. + if src_rank == paddle.distributed.get_rank(): + if file_name not in storage_file_to_state_dict: + # The value in state_dict is not distributed tensor but a normal tensor. + storage_file_to_state_dict[file_name] = paddle.load( + os.path.join(path, file_name) + ) + storage_state_dict = storage_file_to_state_dict[file_name] + assert item.local_tensor_index.tensor_key in storage_state_dict + storage_local_tensor = storage_state_dict[ + item.local_tensor_index.tensor_key + ] + storage_offsets = item.storage_offset + storage_lengths = item.lengths + storage_ends = [ + storage_offset + storage_length + for storage_offset, storage_length in zip( + storage_offsets, storage_lengths + ) + ] + # The storage_chunk_tensor and storage_local_tensor share the same memory. + storage_chunk_tensor = paddle.slice( + storage_local_tensor, + list(range(len(storage_lengths))), + storage_offsets, + storage_ends, + ) + # The read item rank need to be assigned + if item.rank == paddle.distributed.get_rank(): + assert ( + item.local_tensor_index.tensor_key in state_dict + ), f"item:{item}, state_dict:{state_dict}" + cur_local_tensor = ( + state_dict[item.local_tensor_index.tensor_key]._local_value() + if use_dist + else state_dict[item.local_tensor_index.tensor_key] + ) + cur_offsets = item.cur_offset + cur_lengths = item.lengths + cur_ends = [ + cur_offset + cur_length + for cur_offset, cur_length in zip(cur_offsets, cur_lengths) + ] + # The cur_chunk_tensor and cur_local_tensor share the same memory. + cur_chunk_tensor = paddle.slice( + cur_local_tensor, + list(range(len(cur_lengths))), + cur_offsets, + cur_ends, + ) + else: + cur_chunk_tensor = paddle.zeros( + item.lengths, + dtype=state_dict[item.local_tensor_index.tensor_key].dtype, + ) + + if src_rank == item.rank: + # assign value locally + paddle.assign(storage_chunk_tensor, cur_chunk_tensor) + else: + # assign value remotely + if src_rank == paddle.distributed.get_rank(): + paddle.distributed.broadcast( + storage_chunk_tensor, src=src_rank, group=process_group + ) + else: + paddle.distributed.broadcast( + cur_chunk_tensor, src=src_rank, group=process_group + ) + + local_state_dict = ( + {k: v._local_value() for k, v in state_dict.items()} + if use_dist + else state_dict + ) + logger.debug( + f"after load, local_state_dict:{local_state_dict} \n state_dict:{state_dict}" + ) diff --git a/python/paddle/distributed/checkpoint/metadata.py b/python/paddle/distributed/checkpoint/metadata.py new file mode 100644 index 00000000000000..4eb5d559a9c0c4 --- /dev/null +++ b/python/paddle/distributed/checkpoint/metadata.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, List, Tuple + + +@dataclass +class LocalTensorMetadata: + """ + The location of a local tensor in the global tensor. + """ + + global_offset: Tuple[int] + local_shape: Tuple[int] + + +@dataclass(frozen=True) +class LocalTensorIndex: + """ + The identifier of a local tensor. + """ + + tensor_key: str + global_offset: Tuple[int] + + +@dataclass +class Metadata: + state_dict_metadata: Dict[str, List[LocalTensorMetadata]] = None + storage_metadata: Dict[LocalTensorIndex, str] = None diff --git a/python/paddle/distributed/checkpoint/save_state_dict.py b/python/paddle/distributed/checkpoint/save_state_dict.py new file mode 100644 index 00000000000000..4b7f3665d86da2 --- /dev/null +++ b/python/paddle/distributed/checkpoint/save_state_dict.py @@ -0,0 +1,186 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List + +import paddle +from paddle.distributed.communication.group import is_initialized +from paddle.distributed.fleet.utils.log_util import logger + +from .metadata import LocalTensorIndex, LocalTensorMetadata, Metadata +from .utils import compute_local_shape_and_global_offset, flatten_state_dict + + +def check_state_dict(state_dict, process_group): + local_keys = list(state_dict.keys()) + gloabl_keys = [] + paddle.distributed.all_gather_object(gloabl_keys, local_keys, process_group) + for keys in gloabl_keys[1:]: + assert ( + keys == gloabl_keys[0] + ), f"keys:{keys} != first_keys: {gloabl_keys[0]}" + + +def check_file_name(file_name, process_group): + all_unique_id = [] + unique_id = int(file_name.split(".")[0].split("_")[1]) + paddle.distributed.all_gather_object( + all_unique_id, unique_id, process_group + ) + for id in all_unique_id[1:]: + assert ( + id == all_unique_id[0] + ), f"id:{id} != all_unique_id[0]:{file_name}" + + +def merge_state_dict_metadata(global_state_dict): + assert isinstance( + global_state_dict, List + ), "The global_state_dict should be a list." + out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out: + if val in out[key]: + continue + out[key].append(val) + else: + out[key] = [val] + return out + + +def dedup_storage_metadata(global_state_dict): + out = {} + for state_dict in global_state_dict: + for key, val in state_dict.items(): + if key in out: + continue + out[key] = val + return out + + +def save_state_dict( + state_dict, + path, + process_group=None, + coordinator_rank=0, +) -> None: + """ + Save the state_dict of model to path. + + Args: + state_dict(Dict[str, paddle.Tensor]): The state_dict to save. + path(str): The directory to save state_dict. + process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards. + coordinator_rank(int): The rank used to save non distributed values. Rank0 is used by default. + + Examples: + .. code-block:: python + >>> # doctest: +SKIP('Save state dict.') + >>> import paddle + >>> import paddle.distributed as dist + >>> w1 = paddle.arange(32).reshape([4, 8]) + >>> mesh = dist.ProcessMesh([0, 1]) + >>> sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0), dist.Replicate()]) + >>> state_dict = {"w1": sharded_w1} + >>> dist.save_state_dict(state_dict, "./checkpoint") + >>> # doctest: -SKIP + + """ + assert isinstance( + state_dict, dict + ), "The state_dict should be a dictionary." + state_dict = flatten_state_dict(state_dict) + if len(state_dict) > 0: + for val in state_dict.values(): + assert isinstance( + val, paddle.Tensor + ), "Only support dygraph Tensor now, support static DistributedTensor later" + + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + use_dist = True if paddle.distributed.get_world_size() > 1 else False + + if use_dist and process_group is None and not is_initialized(): + # Init the default global process group + paddle.distributed.init_parallel_env() + + unique_id = 0 + file_name = "" + while True: + file_name = f"{paddle.distributed.get_rank()}_{unique_id}.distcp" + if not os.path.exists(os.path.join(path, file_name)): + break + unique_id += 1 + logger.debug(f"file_name:{file_name}") + if use_dist: + check_file_name(file_name, process_group) + # the parameter_name and order in state_dict should be the same + check_state_dict(state_dict, process_group) + metadata = Metadata() + local_state_dict = {} + local_state_dict_metadata = {} + local_storage_metadata = {} + for key, val in state_dict.items(): + if isinstance(val, paddle.Tensor): + # Case1: not initialized means this tensor is placed in another mesh which do not contain this rank + if not val._is_initialized(): + continue + if val.is_dist(): + ( + local_shape, + global_offset, + ) = compute_local_shape_and_global_offset( + val.shape, + val.dist_attr.process_mesh, + val.dist_attr.dims_mapping, + ) + if not local_shape or not global_offset: + continue + local_tensor = val._local_value() + else: + global_offset = [0] * len(val.shape) + local_shape = val.shape + local_tensor = val + local_state_dict[key] = local_tensor + local_state_dict_metadata[key] = LocalTensorMetadata( + global_offset, local_shape + ) + local_storage_metadata[ + LocalTensorIndex(key, tuple(global_offset)) + ] = file_name + global_state_dict_metadata = [] + global_storage_metadata = [] + if use_dist: + paddle.distributed.all_gather_object( + global_state_dict_metadata, local_state_dict_metadata, process_group + ) + paddle.distributed.all_gather_object( + global_storage_metadata, local_storage_metadata, process_group + ) + else: + global_state_dict_metadata.append(local_state_dict_metadata) + global_storage_metadata.append(local_storage_metadata) + + metadata.state_dict_metadata = merge_state_dict_metadata( + global_state_dict_metadata + ) + metadata.storage_metadata = dedup_storage_metadata(global_storage_metadata) + if coordinator_rank == paddle.distributed.get_rank(): + logger.debug(f"metadata:{metadata}") + paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata")) + logger.debug(f"local_state_dict:{local_state_dict}") + paddle.save(local_state_dict, os.path.join(path, file_name)) diff --git a/python/paddle/distributed/checkpoint/utils.py b/python/paddle/distributed/checkpoint/utils.py new file mode 100644 index 00000000000000..32b95198d135ae --- /dev/null +++ b/python/paddle/distributed/checkpoint/utils.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Tuple, Union + +import numpy as np + +import paddle +from paddle.framework import core + + +def get_coordinator(mesh: Union[np.array, List[List[int]]], rank: int): + mesh = paddle.to_tensor(mesh) + rand_coordinator = (mesh == rank).nonzero() + assert rand_coordinator.shape[0] in ( + 0, + 1, + ), f"rand_coordinator.shape: {rand_coordinator.shape}" + return ( + rand_coordinator[0].tolist() if rand_coordinator.shape[0] > 0 else None + ) + + +def compute_local_shape_and_global_offset( + global_shape: List[int], + process_mesh: core.ProcessMesh, + dims_mapping: List[int], +) -> Tuple[Tuple[int], Tuple[int]]: + mesh = np.array(process_mesh.process_ids).reshape(process_mesh.shape) + # deal with cross mesh case + if paddle.distributed.get_rank() not in mesh: + return ((), ()) + rank_coordinator = get_coordinator(mesh, paddle.distributed.get_rank()) + local_shape = copy.copy(global_shape) + global_offset = [0 for _ in global_shape] + for i, dim in enumerate(dims_mapping): + if dim == -1: + continue + else: + assert ( + global_shape[i] % process_mesh.shape[dim] == 0 + ), f"i:{i}, global_shape[i]:{global_shape[i]}, process_mesh.shape[dim]:{process_mesh.shape[dim]}" + local_shape[i] = global_shape[i] // process_mesh.shape[dim] + chunk_idx = rank_coordinator[dim] + global_offset[i] = chunk_idx * local_shape[i] + + return tuple(local_shape), tuple(global_offset) + + +def flatten_state_dict(state_dict): + # TODO, {"model": {"w0": xxx}} -> {model.w0: xxx} + return state_dict diff --git a/python/setup.py.in b/python/setup.py.in index ad87011315b592..9e111270cec552 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -371,6 +371,7 @@ packages=['paddle', 'paddle.dataset', 'paddle.reader', 'paddle.distributed', + 'paddle.distributed.checkpoint', 'paddle.distributed.communication', 'paddle.distributed.communication.stream', 'paddle.distributed.metric', diff --git a/setup.py b/setup.py index 05e7879607aade..7ae0f7c24b57b9 100644 --- a/setup.py +++ b/setup.py @@ -1381,6 +1381,7 @@ def get_setup_parameters(): 'paddle.dataset', 'paddle.reader', 'paddle.distributed', + 'paddle.distributed.checkpoint', 'paddle.distributed.communication', 'paddle.distributed.communication.stream', 'paddle.distributed.metric', diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index 257f716dfa192b..a1759193941f20 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -12,3 +12,10 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_semi_auto_parallel_hybrid_strategy PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") endif() +if((WITH_GPU) AND (LINUX)) + py_test_modules( + test_save_load_state_dict MODULES test_save_load_state_dict ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_save_load_state_dict + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") +endif() diff --git a/test/auto_parallel/hybrid_strategy/load_state_dict.py b/test/auto_parallel/hybrid_strategy/load_state_dict.py new file mode 100644 index 00000000000000..c500853324e713 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/load_state_dict.py @@ -0,0 +1,159 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +from auto_parallel.hybrid_strategy.save_state_dict import ( + get_global_state_dict, +) + +import paddle +import paddle.distributed as dist +from paddle.distributed import load_state_dict +from paddle.distributed.checkpoint.utils import ( + compute_local_shape_and_global_offset, + get_coordinator, +) + + +class TestLoadStateDict: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + + def test_load_state_dict_with_one_device(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + state_dict = dict(zip(list(global_state_dict.keys()), [w1, w2])) + load_state_dict(state_dict, self._ckpt_path) + # check + expect_w1 = saved_w1 + expect_w2 = saved_w2 + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + self.check_tensor_eq(v, expect_state_dict[k]) + + def test_load_state_dict_with_four_devices(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([0, 1, 2, 3]) + sharded_w1 = dist.shard_tensor( + w1, mesh, [dist.Shard(0), dist.Replicate()] + ) + sharded_w2 = dist.shard_tensor( + w2, mesh, [dist.Replicate(), dist.Replicate()] + ) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) + load_state_dict(state_dict, self._ckpt_path) + # check + cur_rank = paddle.distributed.get_rank() + expect_w1 = saved_w1.split(4, axis=0)[cur_rank] + expect_w2 = sharded_w2 + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def test_load_state_dict_with_two_devices(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([0, 1]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(0)]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(1)]) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) + load_state_dict(state_dict, self._ckpt_path) + # check + cur_rank = paddle.distributed.get_rank() + expect_w1 = saved_w1.split(2, axis=0)[cur_rank] + expect_w2 = saved_w2.split(2, axis=1)[cur_rank] + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def test_load_state_dict_with_eight_devices(self): + global_state_dict = get_global_state_dict() + saved_w1, saved_w2 = list(global_state_dict.values()) + w1 = paddle.zeros_like(saved_w1) + w2 = paddle.zeros_like(saved_w2) + mesh = dist.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) + sharded_w1 = dist.shard_tensor(w1, mesh, [dist.Shard(1), dist.Shard(0)]) + sharded_w2 = dist.shard_tensor(w2, mesh, [dist.Shard(0)]) + state_dict = dict( + zip(list(global_state_dict.keys()), [sharded_w1, sharded_w2]) + ) + load_state_dict(state_dict, self._ckpt_path) + # check + cur_rank = paddle.distributed.get_rank() + local_shape, global_offset = compute_local_shape_and_global_offset( + sharded_w1.shape, + sharded_w1.dist_attr.process_mesh, + sharded_w1.dist_attr.dims_mapping, + ) + end_offset = [ + offset + length + for offset, length in zip(global_offset, local_shape) + ] + expect_w1 = paddle.slice( + saved_w1, axes=[0, 1], starts=global_offset, ends=end_offset + ) + cur_coordinator = get_coordinator( + np.array([[0, 1, 2, 3], [4, 5, 6, 7]]), cur_rank + ) + expect_w2 = saved_w2.split(2, axis=0)[cur_coordinator[0]] + expect_state_dict = dict( + zip(list(global_state_dict.keys()), [expect_w1, expect_w2]) + ) + for k, v in state_dict.items(): + assert k in expect_state_dict, k + self.check_tensor_eq(v._local_value(), expect_state_dict[k]) + + def check_tensor_eq(self, a, b, verbose=True): + np1 = a.astype("float32").numpy() + np2 = b.astype("float32").numpy() + np.testing.assert_equal(np1, np2, verbose=verbose) + + def run_test_case(self): + device_num = int(os.getenv("device_num")) + if device_num == 1: + self.test_load_state_dict_with_one_device() + elif device_num == 2: + self.test_load_state_dict_with_two_devices() + elif device_num == 4: + self.test_load_state_dict_with_four_devices() + elif device_num == 8: + self.test_load_state_dict_with_eight_devices() + else: + raise ValueError("device_num should be 1, 2, 4 or 8") + + +if __name__ == '__main__': + TestLoadStateDict().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/save_state_dict.py b/test/auto_parallel/hybrid_strategy/save_state_dict.py new file mode 100644 index 00000000000000..0fd2f5d7049dbf --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/save_state_dict.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle +import paddle.distributed as dist +from paddle.distributed import save_state_dict + + +def get_global_state_dict(): + w1 = paddle.arange(32).reshape([4, 8]) + w2 = paddle.arange(32, 36).reshape([2, 2]) + return {"w1": w1, "w2": w2} + + +class TestSaveStateDict: + def __init__(self): + self._ckpt_path = os.getenv("ckpt_path") + + def test_save_state_dict_with_one_device(self): + global_state_dict = get_global_state_dict() + keys = list(global_state_dict.keys()) + w1, w2 = list(global_state_dict.values()) + state_dict = dict(zip(keys, [w1, w2])) + save_state_dict(state_dict, self._ckpt_path) + + def test_save_state_dict_with_four_devices(self): + global_state_dict = get_global_state_dict() + keys = list(global_state_dict.keys()) + w1, w2 = list(global_state_dict.values()) + mesh = dist.ProcessMesh([0, 1]) + mesh2 = dist.ProcessMesh([2, 3]) + sharded_w1 = dist.shard_tensor( + w1, mesh, [dist.Shard(0), dist.Replicate()] + ) + sharded_w2 = dist.shard_tensor( + w2, mesh2, [dist.Shard(0), dist.Replicate()] + ) + state_dict = dict(zip(keys, [sharded_w1, sharded_w2])) + save_state_dict(state_dict, self._ckpt_path) + + def run_test_case(self): + device_num = int(os.getenv("device_num")) + if device_num == 1: + self.test_save_state_dict_with_one_device() + elif device_num == 4: + self.test_save_state_dict_with_four_devices() + + +if __name__ == "__main__": + TestSaveStateDict().run_test_case() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py index 81b55cf266a08a..a3c2938a7370f9 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp.py @@ -30,6 +30,7 @@ def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) + self._ckpt_path = os.getenv("ckpt_path") self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) paddle.set_device(self._backend) @@ -55,6 +56,19 @@ def test_dp_mp_demo_net(self): self.check_tensor_eq(param, param_base) self.check_tensor_eq(param.grad, param_base.grad) + # save load + state_dict = model.state_dict() + local_state_dict = {} + for k, v in state_dict.items(): + local_state_dict[k] = v._local_value().clone() + paddle.distributed.save_state_dict(state_dict, self._ckpt_path) + for k, v in state_dict.items(): + v._local_value().add_(paddle.ones_like(v._local_value())) + paddle.distributed.load_state_dict(state_dict, self._ckpt_path) + for k, v in state_dict.items(): + assert k in local_state_dict, k + self.check_tensor_eq(v._local_value(), local_state_dict[k]) + def run_test_case(self): self.test_dp_mp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py index ecac26ee46d86d..ddbc66e080b2b7 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_simple_net_dp_mp_pp.py @@ -31,6 +31,7 @@ def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") self._seed = eval(os.getenv("seed")) + self._ckpt_path = os.getenv("ckpt_path") self._mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) self._pp_mesh0 = dist.ProcessMesh( [[0, 1], [2, 3]], dim_names=["x", "y"] @@ -103,6 +104,22 @@ def test_dp_mp_pp_demo_net(self): self.dp_mp_pp_parameters[3], self.base_parameters[3] ) + # save load + state_dict = model.state_dict() + local_state_dict = {} + for k, v in state_dict.items(): + local_state_dict[k] = ( + v._local_value().clone() if v._is_initialized() else None + ) + paddle.distributed.save_state_dict(state_dict, self._ckpt_path) + for k, v in state_dict.items(): + v._local_value().add_(paddle.ones_like(v._local_value())) + paddle.distributed.load_state_dict(state_dict, self._ckpt_path) + for k, v in state_dict.items(): + assert k in local_state_dict, k + if v._is_initialized(): + self.check_tensor_eq(v._local_value(), local_state_dict[k]) + def run_test_case(self): self.test_dp_mp_pp_demo_net() diff --git a/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py new file mode 100644 index 00000000000000..a0b64a374d6274 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_save_load_state_dict.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import collective.test_communication_api_base as test_base + + +class TestSaveLoadStateDict(test_base.CommunicationTestDistBase): + def setUp(self): + self._default_envs = {} + self._changeable_envs = {"device_num": ["1", "2", "4", "8"]} + + def test_save_load_state_dict(self): + # save with 1 device + ckpt_path = tempfile.TemporaryDirectory() + super().setUp(num_of_devices=1, timeout=120, nnode=1) + self.run_test_case( + "save_state_dict.py", + user_defined_envs={"device_num": "1", "ckpt_path": ckpt_path.name}, + ) + + # load with 1, 2, 4, 8 devices + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + envs["ckpt_path"] = ckpt_path.name + super().setUp( + num_of_devices=int(envs["device_num"]), + timeout=180, + nnode=1, + ) + self.run_test_case( + "load_state_dict.py", + user_defined_envs=envs, + ) + ckpt_path.cleanup() + + # save with 4 devices + ckpt_path = tempfile.TemporaryDirectory() + super().setUp(num_of_devices=4, timeout=120, nnode=1) + self.run_test_case( + "save_state_dict.py", + user_defined_envs={"device_num": "4", "ckpt_path": ckpt_path.name}, + ) + # load with 1, 2, 4, 8 devices + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + envs["ckpt_path"] = ckpt_path.name + super().setUp( + num_of_devices=int(envs["device_num"]), + timeout=180, + nnode=1, + ) + self.run_test_case( + "load_state_dict.py", + user_defined_envs=envs, + ) + ckpt_path.cleanup() + + +if __name__ == '__main__': + unittest.main() diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py index 4da980bb466cfa..21da5d0a694425 100644 --- a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_hybrid_strategy.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest import collective.test_communication_api_base as test_base @@ -31,10 +32,13 @@ def test_simple_net_bybrid_strategy(self): self._default_envs, self._changeable_envs ) for envs in envs_list: + ckpt_path = tempfile.TemporaryDirectory() + envs["ckpt_path"] = ckpt_path.name self.run_test_case( "semi_auto_parallel_simple_net_dp_mp.py", user_defined_envs=envs, ) + ckpt_path.cleanup() class TestSemiAutoParallelHybridStrategy(test_base.CommunicationTestDistBase): @@ -55,10 +59,13 @@ def test_simple_net_bybrid_strategy(self): self._default_envs, self._changeable_envs ) for envs in envs_list: + ckpt_path = tempfile.TemporaryDirectory() + envs["ckpt_path"] = ckpt_path.name self.run_test_case( "semi_auto_parallel_simple_net_dp_mp_pp.py", user_defined_envs=envs, ) + ckpt_path.cleanup() if __name__ == "__main__": diff --git a/test/auto_parallel/hybrid_strategy/testslist.csv b/test/auto_parallel/hybrid_strategy/testslist.csv index 8a9e3fe28e21c2..250c0b18e8ff9b 100644 --- a/test/auto_parallel/hybrid_strategy/testslist.csv +++ b/test/auto_parallel/hybrid_strategy/testslist.csv @@ -1,2 +1,3 @@ name,os,arch,timeout,run_type,launcher,num_port,run_serial,envs,conditions test_semi_auto_parallel_hybrid_strategy,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_save_load_state_dict,LINUX,GPU,120,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,