PyTorch实现逻辑回归(B站刘二大大练习题)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
分类问题:
比如手写数字识别,输出的是属于哪个数字的概率

import torchvision   这个数据集中提供示例数据集,download设置True会自动下载
# MNIST
train_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=False, download=True)
# CIFAR 数据集torchvision.datasets.CIFAR10


# 比如二分类,使用sigmoid或其它函数
之前的仿射模型:y heat = x * w + b
逻辑回归模型:y heat = sigmoid(x*w + b)

损失函数也改变了:BCE损失(二分类的交叉熵)
loss = -(y*logy heat + (1-y)*log(1 - y heat))



那么Mini-Batch损失函数:
对它做均值

使用pytorch实现逻辑回归
"""


import torch
import torch.nn.functional as F


# 1.准备数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])  # 3行1列的tensor
y_data = torch.Tensor([[0], [0], [1]])


# 2.使用类来设计模型
class LogisticRegressionModel(torch.nn.Module):  # Module构造出来的对象,会自动构建反向传播过程
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # torch.nn.Linear构造一个对象,参数是权重和偏差,也是继承子Module的会自动进行反向传播
        # sigmoid中没有参数

    def forward(self, x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred


model = LogisticRegressionModel()  # model是callable的  model(x)
# 3.构建损失和优化器
criterion = torch.nn.BCELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 梯度下降 model.parameters()自动完成参数的初始化操作
# optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # 梯度下降

# 4.训练
for epoch in range(1000):
    y_pred = model(x_data)  # 1.前向传播,计算y heat
    loss = criterion(y_pred, y_data)  # 2.计算损失
    print(epoch, loss)  # 打印

    optimizer.zero_grad()  # 梯度会自动计算,务必梯度清零!
    loss.backward()    # 3.反向传播
    optimizer.step()   # 4.更新 update

# 打印权重和偏置
print("w=", model.linear.weight.item())  # model下面的linear,下面的weight
print("b=", model.linear.bias.item())

# 测试模型
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print("y_pred=", y_test.data)
0%