1.Pytorch hook 与 dataparallel 使用—— deoldify 源码解析 part1
Pytorch hook 与 dataparallel 使用—— deoldify 源码解析 part1
在调整项目deoldify从单GPU到多GPU训练时,遭遇了一系列问题,促使我对PyTorch的理解进一步加深。项目中的Unet结构在上采样过程中使用了skip connection,通常做法是硬编码实现,这种方式简洁明了,如何分辨网站源码和小程序源码但若需要改变网络结构,如从resnet调整为resnet,这样的硬编码方式显然不够灵活。
deoldify采取了另一种方法,在需要保存输出的网络层中插入自定义的hook函数,并利用PyTorch的register_forward_hook接口,确保每次前向传播时都能触发该函数,从而保存输出以供后续使用。自定义hook的tms 源码核心代码展示了这一过程。
在单GPU训练中,上述方法运行正常,然而在多GPU环境下,遇到了hook存储的值与concat操作的权重不在同一GPU设备上的问题,引发错误。起初,我误以为nn.DataParallel会自动处理这个问题,党建 源码但事实并非如此,我开始了深入的debug之旅。
首先,成功复现了错误现象,发现存储在hook中的值分布不均,部分在GPU1上,其他在GPU0上。onvif源码这表明nn.DataParallel并没有将hook备份并分发到每个GPU上,而是多个GPU共享同一个Hooks类及接口。进一步检查发现,不同线程对应的hook接口及存储值的内存地址相同,这证实了hook并不适用于多GPU运行环境。
为解决这一问题,参考了相关文献,nat 源码并将hook接口进行了修改,引入当前线程ID作为键,值对应输出,从而实现了线程安全。这一调整使得程序在第一个迭代周期正常运行。值得注意的是,第二个迭代周期又出现了问题,但这与hook的多线程运行无关,详情请见后续文章。
在debug过程中,为了简化操作,插入打印信息来观察多线程运行情况。然而,在获取hook中多线程运行信息时,遇到了异常,因获取`self.stored[key]`时报出`dict找不到key`的错误,这是因为多线程在写入`hook.stored`时,for循环期间警告`self.stored`的大小发生变化,这表明发生了并发错误,部分值并未正确写入。最终,删除了打印代码,程序恢复正常运行。
本次经历不仅解决了多GPU环境下hook使用的问题,也加深了我对PyTorch多GPU运行机制的理解,特别是关于线程安全和并发操作的注意事项。