代码如下:
import torch.nn as nn
class Model(nn.Module):
def __init__(self, input_size, output_size):
super(Model, self).__init__() # 继承父类的方法
self.fc = nn.Linear(input_size, output_size)
def forward(self, input):
output = self.fc(input)
return output
import torch
model = Model(10, 2)
print(next(model.parameters()).device)
device = torch.device('cuda:1')
model = model.to(device)
print(next(model.parameters()).device)
model = model.to('cuda:2')
print(next(model.parameters()).device)
model = model.to(3)
print(next(model.parameters()).device)
model.to(1)
和model.to('cuda:1')
效果一致。因篇幅问题不能全部显示,请点此查看更多更全内容
Copyright © 2019- awee.cn 版权所有 湘ICP备2023022495号-5
违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com
本站由北京市万商天勤律师事务所王兴未律师提供法律服务