Skip to content

Commit 5236d47

Browse files
authored
Fix state dict api en doc (#59793) (#60034)
* exclude xpu * fix save and load doc and not expose api in checkpoint dir * skip example * fix doc
1 parent dcb37aa commit 5236d47

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

python/paddle/distributed/checkpoint/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from .save_state_dict import save_state_dict
16-
from .load_state_dict import load_state_dict
17-
18-
__all__ = [
19-
"save_state_dict",
20-
"load_state_dict",
21-
]

python/paddle/distributed/checkpoint/load_state_dict.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,14 +360,17 @@ def load_state_dict(
360360
) -> None:
361361
"""
362362
Load the state_dict inplace from a checkpoint path.
363+
363364
Args:
364365
state_dict(Dict[str, paddle.Tensor]): The state_dict to load. It will be modified inplace after loading.
365366
path(str): The directory to load checkpoint files.
366367
process_group(paddle.distributed.collective.Group): ProcessGroup to be used for cross-rank synchronization. Use the default process group which contains all cards.
367368
coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.
369+
368370
Example:
369371
.. code-block:: python
370-
>>> # doctest: +SKIP('Load state dict.')
372+
373+
>>> # doctest: +SKIP('run in distributed mode.')
371374
>>> import paddle
372375
>>> import paddle.distributed as dist
373376
>>> ckpt_path = "./checkpoint"

python/paddle/distributed/checkpoint/save_state_dict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def save_state_dict(
8888
8989
Examples:
9090
.. code-block:: python
91-
>>> # doctest: +SKIP('Save state dict.')
91+
92+
>>> # doctest: +SKIP('run in distributed mode')
9293
>>> import paddle
9394
>>> import paddle.distributed as dist
9495
>>> w1 = paddle.arange(32).reshape([4, 8])

0 commit comments

Comments
 (0)