♊ It Takes Two
Problem
Using data parallelism, implement a neural network that classifies CIFAR-10 across multiple machines or multiple GPUs with PyTorch's DistributedDataParallel module. The model architecture is your choice.
Solution
I don't have two GPUs, so I made do by running the CPU and the GPU "in parallel" instead (one process per device). The script below spawns one process per device, initializes a gloo process group (gloo handles both CPU and CUDA tensors), shards the training set with a DistributedSampler, and wraps the model in DistributedDataParallel.
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

# One process per device: rank 0 trains on the CPU, rank 1 on the GPU.
devices = [torch.device('cpu'), torch.device('cuda')]
class Net(nn.Module):
    """Five conv blocks, each halving the spatial size, followed by an MLP head."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        )
        self.dense = nn.Sequential(
            nn.Linear(128, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        # After five 2x2 poolings the 32x32 input is a 1x1 map with 128 channels.
        x = x.view(-1, 128)
        x = self.dense(x)
        return x
def train(rank, world_size):
    # gloo is the one backend that supports both CPU and CUDA tensors,
    # which is what lets the two ranks live on different device types.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    torch.manual_seed(42)
    device = devices[rank]
    net = Net().to(device)
    # DDP broadcasts rank 0's parameters at construction and all-reduces
    # gradients during backward(), keeping the replicas in sync.
    net = DistributedDataParallel(net)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    trainset = torchvision.datasets.CIFAR10(root='./cifar-10', train=True, download=True,
                                            transform=transforms.ToTensor())
    # Shard the dataset so each rank trains on a disjoint subset; plain
    # shuffle=True would make both processes redundantly see the full set.
    sampler = DistributedSampler(trainset, num_replicas=world_size, rank=rank)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=5, sampler=sampler)

    # Strictly redundant with DDP's own broadcast, but this is the standard
    # checkpoint pattern: rank 0 saves, then every rank loads the same weights.
    CHECKPOINT = "model.checkpoint.pkl"
    if rank == 0:
        torch.save(net.state_dict(), CHECKPOINT)
    dist.barrier()
    net.load_state_dict(torch.load(CHECKPOINT, map_location=device))

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # the gradient all-reduce happens inside this call
        optimizer.step()

    dist.barrier()  # wait for every rank to finish the epoch
    if rank == 0:
        os.remove(CHECKPOINT)
        # Save the underlying module's weights rather than pickling the DDP wrapper.
        torch.save(net.module.state_dict(), './model.pkl')
    dist.destroy_process_group()
if __name__ == '__main__':
    # Rendezvous address for the process group; both processes are local here.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    mp.freeze_support()
    mp.spawn(train,
             args=(len(devices),),
             nprocs=len(devices),
             join=True)
```
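
With two actual GPUs the same structure carries over; the idiomatic changes are to switch the backend from gloo to nccl and to pin each rank to its own device. A minimal sketch of the variant (train_gpu is my name, not part of the original; the omitted middle is identical to train above):

```python
def train_gpu(rank, world_size):
    # nccl is the preferred backend when every rank owns a CUDA device.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)
    # device_ids pins this replica (and its inputs/outputs) to one GPU.
    net = DistributedDataParallel(Net().to(device), device_ids=[rank])
    # ... sampler, loader, and training loop exactly as in train() ...
    dist.destroy_process_group()
```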
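
As a quick sanity check (not part of the assignment), the saved weights can be reloaded into a plain Net and scored on the CIFAR-10 test split. This sketch assumes the ./model.pkl state dict written above and runs in the same script where Net is defined:

```python
def evaluate():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = Net().to(device)
    net.load_state_dict(torch.load('./model.pkl', map_location=device))
    net.eval()
    testset = torchvision.datasets.CIFAR10(root='./cifar-10', train=False, download=True,
                                           transform=transforms.ToTensor())
    testloader = torch.utils.data.DataLoader(testset, batch_size=100)
    correct = total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            predictions = net(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    print(f'test accuracy: {correct / total:.2%}')
```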