cannot re-initialize CUDA in forked subprocess

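This error appears in PyTorch when CUDA initialization conflicts with the worker processes that DataLoader starts when num_workers is greater than 0.

As the error message suggests, the fix is to pass the multiprocessing_context parameter to DataLoader and set it to 'spawn'. Workers are normally created with fork, which produces child processes (not threads) that inherit the parent's state, so the problem is easy to overlook until CUDA has been initialized in the parent.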

1. Fixing "cannot re-initialize CUDA in forked subprocess"

import torch
import torch.multiprocessing as mp


def get_mean_and_std_4channel(dataset):
    '''Compute the per-channel mean and std of a 4-channel dataset.'''
    # Set the global start method to 'spawn' (this may only be done once per program).
    mp.set_start_method('spawn')
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=2,
        multiprocessing_context='spawn')  # workers are spawned, not forked
    mean = torch.zeros(4)
    std = torch.zeros(4)
    print('==> Computing the 4 channel mean and std..')
    for inputs, targets in dataloader:
        for i in range(4):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    # batch_size=1, so the number of batches equals len(dataset)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

1.1 Explanation

The error message indicates a conflict between PyTorch's CUDA initialization and the multiprocessing used by the DataLoader. A forked worker inherits the parent's already-initialized CUDA state, and CUDA cannot be initialized again in that child. One way to resolve this is to set the DataLoader's multiprocessing_context argument to 'spawn' instead of the 'fork' context that is the usual default on Linux. Each worker then starts as a fresh process that does not inherit the parent's CUDA state, so the conflict does not arise.

In the code above, we first import the torch.multiprocessing module and set the start method to 'spawn' with set_start_method. We then pass 'spawn' to the DataLoader as the multiprocessing_context argument. This should resolve the CUDA initialization conflict and allow the DataLoader to run without errors.

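Multiprocessing context: the environment in which the worker processes are created and run.
In Python the start method can be spawn, fork, or forkserver. The choice determines how child processes are created and how shared memory is managed.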
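
To make the idea of a context concrete, here is a minimal sketch using the standard-library multiprocessing module (torch.multiprocessing wraps the same API). The helper name describe_start_methods is hypothetical and only for illustration. It lists the start methods the current platform supports and requests a 'spawn' context object without changing the global default.

```python
import multiprocessing as mp


def describe_start_methods():
    """Return the supported start methods and a standalone 'spawn' context."""
    # Start methods available on this platform ('spawn' exists everywhere).
    methods = mp.get_all_start_methods()
    # get_context('spawn') returns a context object; the global default is untouched.
    ctx = mp.get_context("spawn")
    return methods, ctx


if __name__ == "__main__":
    methods, ctx = describe_start_methods()
    print("supported start methods:", methods)
    print("context start method:", ctx.get_start_method())
```

A context object obtained this way can be handed to APIs that accept one, which is what the multiprocessing_context argument of DataLoader is for.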

2. RuntimeError: context has already been set

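If this error appears after making the change, the multiprocessing start method has already been set somewhere else in the program, for example in another module or a third-party library.
Setting it a second time raises the error.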
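
The behavior can be reproduced with a small sketch. It uses the standard-library multiprocessing module rather than torch.multiprocessing so that it runs without PyTorch; the function name try_set_twice is hypothetical. The first call passes force=True only so that the sketch works even if a start method was already chosen earlier.

```python
import multiprocessing as mp


def try_set_twice(method="spawn"):
    """Set the start method, then set it again; return the second call's error text."""
    # force=True guarantees this first call succeeds regardless of earlier state.
    mp.set_start_method(method, force=True)
    try:
        # Without force=True, a second call is rejected.
        mp.set_start_method(method)
    except RuntimeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(try_set_twice())
```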

2.1 Solution 1

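Make sure the start method is set only once, in the main process, before any child process is created.
You can check whether the current process is the main process with multiprocessing.current_process().name.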

import torch.multiprocessing as mp

if __name__ == '__main__':
    if mp.current_process().name == 'MainProcess':
        mp.set_start_method('spawn')
    # rest of your code here
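
A complementary guard, sketched here as an assumption rather than part of the original fix, is to ask whether a start method has already been chosen before setting one. The helper ensure_start_method is a hypothetical name; get_start_method(allow_none=True) is the standard-library call that returns None when nothing has been set yet.

```python
import multiprocessing as mp


def ensure_start_method(method="spawn"):
    """Set the start method only if none is set yet; return the active method."""
    # allow_none=True reports None instead of silently fixing the default.
    current = mp.get_start_method(allow_none=True)
    if current is None:
        mp.set_start_method(method)
        current = method
    return current


if __name__ == "__main__":
    print("active start method:", ensure_start_method("spawn"))
```

Calling the helper repeatedly is safe because it never sets the method a second time. Note that it returns whatever method is already active, which may differ from the one requested.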

2.2 Solution 2

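If the start method was set inside a third-party library, reuse the existing multiprocessing context instead of creating another one.
Call get_context() and pass the result as the multiprocessing_context argument.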

import torch
import torch.multiprocessing as mp


def get_mean_and_std_4channel(dataset):
    '''Compute the per-channel mean and std of a 4-channel dataset.'''
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=2,
        multiprocessing_context=mp.get_context())  # reuse the existing context
    mean = torch.zeros(4)
    std = torch.zeros(4)
    print('==> Computing the 4 channel mean and std..')
    for inputs, targets in dataloader:
        for i in range(4):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    # batch_size=1, so the number of batches equals len(dataset)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

2.3 Solution 3

Set num_workers = 0, so that data is loaded in the main process and no worker subprocesses are created.