本文共 3892 字,大约阅读时间需要 12 分钟。
import torch# prepare dataset# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征x_data = torch.tensor([[1.0], [2.0], [3.0]])y_data = torch.tensor([[2.0], [4.0], [6.0]])
#design model using classclass LinearModel(torch.nn.Module): def __init__(self): super(LinearModel, self).__init__() # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的 # 该线性层需要学习的参数是w和b 获取w/b的方式分别是~linear.weight/linear.bias self.linear = torch.nn.Linear(1, 1) def forward(self, x): y_pred = self.linear(x) return y_pred model = LinearModel()
# construct loss and optimizer# criterion = torch.nn.MSELoss(size_average = False)criterion = torch.nn.MSELoss(reduction = 'sum')optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # model.parameters()自动完成参数的初始化操作
# training cycle forward, backward, updatefor epoch in range(100): y_pred = model(x_data) # forward:predict loss = criterion(y_pred, y_data) # forward: loss print(epoch, loss.item()) optimizer.zero_grad() # the grad computer by .backward() will be accumulated. so before backward, remember set the grad to zero loss.backward() # backward: autograd,自动计算梯度 optimizer.step() # update 参数,即更新w和b的值 print('w = ', model.linear.weight.item())print('b = ', model.linear.bias.item()) x_test = torch.tensor([[4.0]])y_test = model(x_test)print('y_pred = ', y_test.data)
0 80.818939208984381 35.978454589843752 16.0167388916015623 7.1303482055664064 3.17437839508056645 1.413290262222296 0.62930101156234747 0.280290037393569958 0.124917775392532359 0.05574841052293777510 0.02495409548282623311 0.0112434029579162612 0.00513792317360639613 0.002417960204184055314 0.00120529637206345815 0.000663561047986149816 0.000420582800870761317 0.0003106167423538863718 0.0002598843711894005519 0.000235560204600915320 0.0002230154932476580121 0.0002157215640181675622 0.0002108066109940409723 0.0002069667098112404324 0.0002036372025031596425 0.0002005462447414174726 0.000197592366021126527 0.0001947227865457534828 0.0001919094211189076329 0.000189151105587370730 0.0001864261139417067231 0.0001837461022660136232 0.0001811022812034934833 0.0001784954656613990734 0.000175936991581693335 0.0001734078396111726836 0.0001709190692054107837 0.0001684581366134807538 0.000166042751516215539 0.000163650780450552740 0.0001613031199667602841 0.0001589856401551514942 0.0001566951104905456343 0.000154443943756632544 0.0001522216480225324645 0.00015003679436631546 0.0001478776393923908547 0.0001457534381188452248 0.0001436610036762431349 0.0001415899023413658150 0.00013956235488876751 0.0001375535211991518752 0.0001355809072265401553 0.0001336295972578227554 0.0001317124406341463355 0.000129817432025447556 0.0001279519347008317757 0.000126114318845793658 0.000124295824207365559 0.0001225149608217179860 0.0001207527238875627561 0.0001190143593703396662 0.0001173071577795781263 0.0001156185535364784364 0.0001139573942054994465 0.0001123199690482579266 0.0001107105854316614667 0.0001091119338525459268 0.000107544794445857469 0.000106004845292773170 0.0001044757154886610871 0.000102980106021277672 0.0001014950030366890173 0.0001000352058326825574 9.860043064691126e-0575 9.718433284433559e-0576 9.578699973644689e-0577 9.44119383348152e-0578 9.305298590334132e-0579 9.17144789127633e-0580 9.040001896210015e-0581 8.909945609048009e-0582 8.782246732152998e-0583 8.655633428134024e-0584 8.53120072861202e-0585 8.40857537696138e-0586 8.288319077109918e-0587 8.168668136931956e-0588 8.051737677305937e-0589 7.935728353913873e-0590 7.821589679224417e-0591 7.709144847467542e-0592 7.598604133818299e-0593 7.489340350730345e-0594 7.381867180811241e-0595 7.275665120687336e-0596 7.171095057856292e-0597 7.068333798088133e-0598 6.96621646056883e-0599 6.866479816380888e-05w = 2.005516290664673b = -0.012539949268102646y_pred = tensor([[8.0095]])
转载地址:http://gyali.baihongyu.com/