博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
5.PyTorch实现逻辑回归(二分类)
阅读量:4203 次
发布时间:2019-05-26

本文共 3892 字,大约阅读时间需要 12 分钟。

1 Prepare dataset

import torch# prepare dataset# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征x_data = torch.tensor([[1.0], [2.0], [3.0]])y_data = torch.tensor([[2.0], [4.0], [6.0]])

2 Design model using Class

#design model using classclass LinearModel(torch.nn.Module):    def __init__(self):        super(LinearModel, self).__init__()        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias        self.linear = torch.nn.Linear(1, 1)     def forward(self, x):        y_pred = self.linear(x)        return y_pred model = LinearModel()

3 Construct loss and optimizer

# construct loss and optimizer# criterion = torch.nn.MSELoss(size_average = False)criterion = torch.nn.MSELoss(reduction = 'sum')optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # model.parameters()自动完成参数的初始化操作

4 Training cycle

# training cycle forward, backward, updatefor epoch in range(100):    y_pred = model(x_data) # forward:predict    loss = criterion(y_pred, y_data) # forward: loss    print(epoch, loss.item())     optimizer.zero_grad() # the grad computer by .backward() will be accumulated. so before backward, remember set the grad to zero    loss.backward() # backward: autograd,自动计算梯度    optimizer.step() # update 参数,即更新w和b的值 print('w = ', model.linear.weight.item())print('b = ', model.linear.bias.item()) x_test = torch.tensor([[4.0]])y_test = model(x_test)print('y_pred = ', y_test.data)
0 80.818939208984381 35.978454589843752 16.0167388916015623 7.1303482055664064 3.17437839508056645 1.413290262222296 0.62930101156234747 0.280290037393569958 0.124917775392532359 0.05574841052293777510 0.02495409548282623311 0.0112434029579162612 0.00513792317360639613 0.002417960204184055314 0.00120529637206345815 0.000663561047986149816 0.000420582800870761317 0.0003106167423538863718 0.0002598843711894005519 0.000235560204600915320 0.0002230154932476580121 0.0002157215640181675622 0.0002108066109940409723 0.0002069667098112404324 0.0002036372025031596425 0.0002005462447414174726 0.000197592366021126527 0.0001947227865457534828 0.0001919094211189076329 0.000189151105587370730 0.0001864261139417067231 0.0001837461022660136232 0.0001811022812034934833 0.0001784954656613990734 0.000175936991581693335 0.0001734078396111726836 0.0001709190692054107837 0.0001684581366134807538 0.000166042751516215539 0.000163650780450552740 0.0001613031199667602841 0.0001589856401551514942 0.0001566951104905456343 0.000154443943756632544 0.0001522216480225324645 0.00015003679436631546 0.0001478776393923908547 0.0001457534381188452248 0.0001436610036762431349 0.0001415899023413658150 0.00013956235488876751 0.0001375535211991518752 0.0001355809072265401553 0.0001336295972578227554 0.0001317124406341463355 0.000129817432025447556 0.0001279519347008317757 0.000126114318845793658 0.000124295824207365559 0.0001225149608217179860 0.0001207527238875627561 0.0001190143593703396662 0.0001173071577795781263 0.0001156185535364784364 0.0001139573942054994465 0.0001123199690482579266 0.0001107105854316614667 0.0001091119338525459268 0.000107544794445857469 0.000106004845292773170 0.0001044757154886610871 0.000102980106021277672 0.0001014950030366890173 0.0001000352058326825574 9.860043064691126e-0575 9.718433284433559e-0576 9.578699973644689e-0577 9.44119383348152e-0578 9.305298590334132e-0579 9.17144789127633e-0580 9.040001896210015e-0581 8.909945609048009e-0582 8.782246732152998e-0583 8.655633428134024e-0584 8.53120072861202e-0585 8.40857537696138e-0586 8.288319077109918e-0587 8.168668136931956e-0588 8.051737677305937e-0589 7.935728353913873e-0590 7.821589679224417e-0591 7.709144847467542e-0592 7.598604133818299e-0593 7.489340350730345e-0594 7.381867180811241e-0595 7.275665120687336e-0596 7.171095057856292e-0597 7.068333798088133e-0598 6.96621646056883e-0599 6.866479816380888e-05w =  2.005516290664673b =  -0.012539949268102646y_pred =  tensor([[8.0095]])

转载地址:http://gyali.baihongyu.com/

你可能感兴趣的文章
害怕自动化(1)
查看>>
深圳市软件质量提升工程系列活动——安全测试百人大课堂
查看>>
LoadRunner如何在脚本运行时修改log设置选项?
查看>>
QC数据库表结构
查看>>
自动化测试工具的3个关键部分
查看>>
测试工具厂商的编程语言什么时候“退休”?
查看>>
资源监控工具 - Hyperic HQ
查看>>
LoadRunner中Concurrent与Simultaneous的区别
查看>>
SiteScope - Agentless监控
查看>>
QTP的智能识别(Smart Identification)过程
查看>>
LoadRunner各协议所需耗费的内存资源表
查看>>
AutomatedQA收购Smart Bear?
查看>>
使用QTP进行WEB页面性能测试
查看>>
LoadRunner的VS.NET 2005插件
查看>>
LoadRunner中如何验证下载的文件大小、统计下载时间、度量下载速度?
查看>>
LoadRunner脚本评审Checklist
查看>>
在LoadRunner中设置HTTP请求time-out的时间
查看>>
在LoadRunner脚本中实现随机ThinkTime
查看>>
LoadRunner9.51中文帮助手册
查看>>
RPT录制问题
查看>>