使用昇腾910B显卡训练ProPainter模型
1. 项目介绍
ProPainter是一个去掉视频里的静止和移动水印图像的AI项目。ProPainter项目地址:https://github.com/sczhou/ProPainter。
经过改写的资源包propainter_ascend20251029.zip支持使用昇腾910B显卡训练ProPainter,里面包含改写后的代码,已标注的训练数据(包括训练集和验证集)。
2. 安装训练环境
2.1 项目运行环境的主要库的版本
CANN:8.3
mindspore:2.7.1
numpy:1.26.0
python:3.11.14
matplotlib:3.7.0
torch:2.7.1
torchision:0.22.1
torch_npu:2.7.1
2.2 安装环境
conda create -n propainter python==3.11.14
conda activate propainter
pip3 install -r requirements.txt
pip3 install mindspore==2.7.1 -i https://repo.mindspore.cn/pypi/simple --trusted-host repo.mindspore.cn --extra-index-url https://repo.huaweicloud.cn/repository/pypi/simple
pip3 install torch==2.7.1
pip3 install matplotlib==3.7.0
pip3 install numpy==1.26.0
3. 训练模型
3.1 部署训练程序
mkdir -p /home/a/apply
解压propainter_ascend20251029.zip
unzip /home/a/apply/propainter_ascend20251029.zip -d /home/a/apply/
cd /home/a/apply/draw
3.2 调整模型训练参数
可以在train_propainter.json修改模型训练参数。
video_root:训练数据目录
batch_size:每轮训练数据量
num_workers:线程数量
3.3 训练模型
如果因显卡的显存不足而出现训练异常,调小batch_size,可以使小显存的显卡也可以训练
python train.py -c configs/train_propainter.json
4. 代码改写过程说明
4.1 通用的代码迁移方法
如果想使用昇腾910B显卡训练ProPainter,需要改写代码。使用的Ascend Extension for PyTorch(https://gitcode.com/Ascend/pytorch)的版本:7.1.0。
在原代码train.py添加这些代码:
from torch_npu.contrib import transfer_to_npu
torch.npu.config.allow_internal_format = False
torch.npu.set_compile_mode(jit_compile=False)

各行代码功能说明:
1. from torch_npu.contrib import transfer_to_npu
· 从torch_npu的contrib模块导入transfer_to_npu函数
· 该函数用于将数据和模型传输到NPU设备上运行
2. torch.npu.config.allow_internal_format = False
· 配置NPU不使用内部格式
· 设置为False意味着使用标准格式,可能牺牲一些性能但确保更好的兼容性
3. torch.npu.set_compile_mode(jit_compile=False)
· 设置NPU编译模式,禁用JIT即时编译
· 关闭JIT编译可以避免一些编译时错误,便于调试,但可能影响运行效率
4.2 异常情况处理方法
报错1:
AttributeError: 'torch_npu._C._NPUDeviceProperties' object has no attribute 'multi_processor_count'
原因:算子不支持转译;
解决方案:改写相关代码,尽量不使用原函数相关的库。
报错2:
AttributeError: 'FigureCanvasAgg' object has no attribute 'tostring_rgb'
原因:matplotlib==3.11库不支持相关属性;
解决方案:安装matplotlib==3.7,这个版本可以支持相关属性。
报错3:
Traceback (most recent call last):
File "/data/disk3/bob20251028/apply/draw20251028/draw/train.py",
line 111, in <module> mp.spawn(main_worker, nprocs=8, args=(config,2))
原因:nprocs=8是多卡模式;
解决方案:将nprocs参数的值设置为1。

报错4:
Conv3dv2 only support static shape.
[FUNC:FusionII FILE:conv3d_to_conv3dv2_fusion_pass.cc LINE :409]
原因:算子不支持转译,torch.nn.Module相关;
解决方案:添加torch.npu.config.allow_internal_format = False和torch.npu.set_compile_mode(jit_compile=False)
- 点赞
- 收藏
- 关注作者
评论(0)