yuyi
知不可乎骤得,托遗响于悲风
记录自己使用 pytorch 进行深度学习过程中遇到的一些问题和解决办法
㊙ 这些解决方法未必是最佳方案
使用单机多显卡时,可能每张显卡的显存大小不是一致的。而 DataParallel 默认是根据显卡的数量对 batchsize 进行均分,每张显卡都会被分配到一样大小的显存消耗。
修改 torch > nn > Parallel > scatter_gather.py
的 scatter_map
函数如下
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
# 获得 batchsize 的数量
batch_size = obj.shape[0]
# 将 2/3 的 batchsize 传入第一个 gpu, 剩下的分给第二个
num1 = 2*batch_size//3
return Scatter.apply(target_gpus, [num1, batch_size-num1], dim, obj)
...