Pytorch 02 (官方教程)

Tensor

是类似数组或矩阵(多维数组)的数据结构,在 pytorch 会将模型的输入和输出都转换为一个 tensor 。

Dataset

即数据集,用于给模型训练或进行预测。 DataLoader 则是对 Dataset 的封装,用于方便的使用数据集。

Transforms

用于对数据进行一些处理,使之更加适合被用于训练模型,因为原始数据的格式很可能不满足我们的需求。

构建神经网络

通过创建 nn.Module 的子类定义一个神经网络。在 __init__ 中初始化网络层级结构。每个 nn.Module 的子类还需要实现处理输入数据的 forward 方法。

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x): # 不要直接调用该方法
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)