Reading and Writing Checkpoints
Last updated: 2026-05-14
bostorchconnector can read and write checkpoints stored on BOS directly.
The following example saves a ResNet checkpoint for one epoch to BOS and loads it back:
```python
from bostorchconnector import BosCheckpoint, BosClientConfig

import torchvision
import torch

# Fill in <BUCKET>, <KEY>, and the matching endpoint
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"
config = BosClientConfig(log_level=1)
checkpoint = BosCheckpoint(endpoint=ENDPOINT, bos_client_config=config)

model = torchvision.models.resnet18()

# Save the checkpoint to BOS
with checkpoint.writer(CHECKPOINT_URI + "epoch0.ckpt") as writer:
    torch.save(model.state_dict(), writer)

# Load the checkpoint from BOS
with checkpoint.reader(CHECKPOINT_URI + "epoch0.ckpt") as reader:
    state_dict = torch.load(reader)

model.load_state_dict(state_dict)
```
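The example builds each object path by appending a file name to CHECKPOINT_URI. This works only because the prefix ends in "/". The sketch below makes that joining explicit and shows how a bos:// URI splits into bucket and key, assuming the usual bos://<BUCKET>/<KEY> layout. It uses two hypothetical helpers, split_bos_uri and checkpoint_uri, which are not part of bostorchconnector. It needs only the standard library, so it runs without BOS access.

```python
from urllib.parse import urlparse


def split_bos_uri(uri: str) -> tuple[str, str]:
    """Split a bos://<BUCKET>/<KEY> URI into (bucket, key). Hypothetical helper."""
    parsed = urlparse(uri)
    if parsed.scheme != "bos" or not parsed.netloc:
        raise ValueError(f"not a bos:// URI: {uri!r}")
    # The network location is the bucket; the path without its leading "/" is the key
    return parsed.netloc, parsed.path.lstrip("/")


def checkpoint_uri(prefix: str, filename: str) -> str:
    """Join a checkpoint prefix and a file name, adding "/" if missing. Hypothetical helper."""
    if not prefix.endswith("/"):
        prefix += "/"
    return prefix + filename


uri = checkpoint_uri("bos://my-bucket/runs/resnet18", "epoch0.ckpt")
print(uri)                 # bos://my-bucket/runs/resnet18/epoch0.ckpt
print(split_bos_uri(uri))  # ('my-bucket', 'runs/resnet18/epoch0.ckpt')
```

urlparse treats "?" and "#" as query and fragment delimiters. Keys that contain those characters would need different handling.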
Distributed Checkpoint (DCP)
bostorchconnector supports PyTorch Distributed Checkpoint (DCP) through the following components:
- BosStorageWriter: implements PyTorch's StorageWriter interface.
- BosStorageReader: implements PyTorch's StorageReader interface.
- BosFileSystem: implements PyTorch's FileSystemBase interface.
Together, these components integrate BOS with PyTorch Distributed Checkpoint, so distributed model checkpoints can be saved to and loaded from BOS efficiently.
Prerequisites and Installation
PyTorch 2.3 or later is required. Install the package with the dcp extra:
```shell
pip install "bostorchconnector[dcp]"
```
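If the same code runs in several environments, you may want it to fail early when the installed PyTorch is older than 2.3. The helper below, meets_min_version, is hypothetical and only a sketch. It compares the major and minor components as integers, because a plain string comparison would wrongly rank "2.10" below "2.3". In a real environment you would pass it torch.__version__.

```python
import re


def meets_min_version(version: str, minimum: tuple[int, int] = (2, 3)) -> bool:
    """Return True if a version string such as "2.3.1+cu121" is >= minimum. Hypothetical helper."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Tuple comparison orders by major version first, then minor
    return (major, minor) >= minimum


print(meets_min_version("2.3.1+cu121"))  # True
print(meets_min_version("2.2.2"))        # False
print(meets_min_version("2.10.0"))       # True
```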
Example
```python
from bostorchconnector.dcp import BosStorageWriter, BosStorageReader
from bostorchconnector import BosClientConfig

import torchvision
import torch.distributed.checkpoint as DCP

# Configuration
CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"

model = torchvision.models.resnet18()

# Configurable parameters: credentials_path, log_level, log_path, part_size, pool_threads_num, max_attempts
config = BosClientConfig(part_size=16 * 1024 * 1024, pool_threads_num=64)

# Save a distributed checkpoint to BOS
bos_storage_writer = BosStorageWriter(
    endpoint=ENDPOINT,
    path=CHECKPOINT_URI,
    bos_client_config=config,  # optional
    thread_count=4,  # optional: number of I/O threads used for writing
    overwrite=True,
)
DCP.save(
    state_dict=model.state_dict(),
    storage_writer=bos_storage_writer,
)

# Load the distributed checkpoint from BOS
model = torchvision.models.resnet18()
model_state_dict = model.state_dict()
bos_storage_reader = BosStorageReader(
    endpoint=ENDPOINT,
    path=CHECKPOINT_URI,
    bos_client_config=config,  # optional
)
# DCP.load fills model_state_dict in place
DCP.load(
    state_dict=model_state_dict,
    storage_reader=bos_storage_reader,
)
model.load_state_dict(model_state_dict)
```
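The configuration comment above lists part_size and pool_threads_num among the BosClientConfig options. This page does not define part_size. Assuming it is the chunk size used for multipart uploads, as the name suggests, the arithmetic below shows how a checkpoint file would be split. The split_into_parts helper and the 45 MiB file size are hypothetical. Larger parts mean fewer requests. Smaller parts give a thread pool more independent units to upload in parallel.

```python
MIB = 1024 * 1024


def split_into_parts(object_size: int, part_size: int = 16 * MIB) -> list[int]:
    """Byte length of each part when object_size bytes are cut into part_size chunks. Hypothetical helper."""
    if part_size <= 0:
        raise ValueError("part_size must be positive")
    full, remainder = divmod(object_size, part_size)
    # Every part is exactly part_size bytes except possibly the last one
    return [part_size] * full + ([remainder] if remainder else [])


# A hypothetical 45 MiB checkpoint file with part_size = 16 MiB, as in the example
print([size // MIB for size in split_into_parts(45 * MIB)])    # [16, 16, 13]
# The same file with 8 MiB parts
print(len(split_into_parts(45 * MIB, part_size=8 * MIB)))      # 6
```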
PyTorch Lightning
bostorchconnector includes a PyTorch Lightning integration called BosLightningCheckpoint, which implements Lightning's CheckpointIO interface. With it, PyTorch Lightning can read and write checkpoints on BOS.
Installation
```shell
pip install "bostorchconnector[lightning]"
```
Example
```python
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
# WikiText2 and LightningTransformer are the demo dataset and model shipped with Lightning
from lightning.pytorch.demos import LightningTransformer, WikiText2
from torch.utils.data import DataLoader
from bostorchconnector.lightning import BosLightningCheckpoint

from fsspec.registry import register_implementation
import bosfs

CHECKPOINT_URI = "bos://<BUCKET>/<KEY>/"
ENDPOINT = "http://bj.bcebos.com"

save_only_latest = True

# Register bosfs with fsspec so that bos:// paths can be resolved
# (Lightning's ModelCheckpoint uses fsspec for path handling)
register_implementation("bos", bosfs.BOSFileSystem)

dataset = WikiText2()
dataloader = DataLoader(dataset, num_workers=2)

model = LightningTransformer(vocab_size=dataset.vocab_size)
bos_lightning_checkpoint = BosLightningCheckpoint(endpoint=ENDPOINT)

checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINT_URI,
    save_top_k=1 if save_only_latest else -1,
    every_n_train_steps=1,
    filename="checkpoint-{epoch:02d}-{step:02d}",
    enable_version_counter=True,
)

trainer = Trainer(
    plugins=[bos_lightning_checkpoint],
    callbacks=[checkpoint_callback],
    min_epochs=4,
    max_epochs=5,
    max_steps=3,
)
trainer.fit(model, dataloader)

# Read the checkpoint back from BOS
r_trainer = Trainer(
    plugins=[bos_lightning_checkpoint],
    min_epochs=4,
    max_epochs=5,
    max_steps=3,
)
# Load the checkpoint in `ckpt_path` before training
r_trainer.fit(model, dataloader, ckpt_path=CHECKPOINT_URI + "checkpoint-epoch=00-step=03.ckpt")
```
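The file name passed as ckpt_path comes from the filename template given to ModelCheckpoint. This assumes Lightning's default auto_insert_metric_name=True, which renders "{epoch:02d}" as "epoch=00". It also assumes that enable_version_counter=True appends "-v1", "-v2", and so on when a name is already taken. The hypothetical helper checkpoint_filename below reproduces that naming for this one template only. It is not Lightning code. In practice, listing the checkpoint directory is more reliable than rebuilding names by hand.

```python
from typing import Optional


def checkpoint_filename(epoch: int, step: int, version: Optional[int] = None) -> str:
    """Mimic the name built from "checkpoint-{epoch:02d}-{step:02d}". Hypothetical helper."""
    # auto_insert_metric_name=True puts "name=" in front of each formatted value
    name = f"checkpoint-epoch={epoch:02d}-step={step:02d}"
    # With enable_version_counter=True, a clashing name gets a "-v<N>" suffix
    if version is not None:
        name += f"-v{version}"
    return name + ".ckpt"


print(checkpoint_filename(0, 3))             # checkpoint-epoch=00-step=03.ckpt
print(checkpoint_filename(0, 3, version=1))  # checkpoint-epoch=00-step=03-v1.ckpt
```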
