Pytorch 02 (官方教程)
Tensor
是类似数组或矩阵(多维数组)的数据结构,在 pytorch 会将模型的输入和输出都转换为一个 tensor 。
Dataset
即数据集,用于给模型训练或进行预测。 DataLoader 则是对 Dataset 的封装,用于方便的使用数据集。
Transforms
用于对数据进行一些处理,使之更加适合被用于训练模型,因为原始数据的格式很可能不满足我们的需求。
构建神经网络
通过创建 nn.Module
的子类定义一个神经网络。在 __init__
中初始化网络层级结构。每个 nn.Module
的子类还需要实现处理输入数据的 forward
方法。
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x): # 不要直接调用该方法
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)