Pytorch 训练完后的模型保存与加载

前言

上一节我们成功的训练了模型,现在我们加一行代码,保存模型

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

#获取GPU设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

读取文件
df = pd.read_csv(
    'C:/Users/huzongsheng/Desktop/Iris.csv'
)

字符串替换

df['Species'] = df['Species'].map(
    {
        'Iris-setosa':0.0,
        'Iris-versicolor':1.0,
        'Iris-virginica':2.0
    }
)

#数据清洗与预处理

X=df.drop(['Species','Id'],axis=1)
y=df['Species']
X=X.values
y=y.values

#测试集与训练集切分

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train=torch.FloatTensor(X_train).to(device)
X_test=torch.FloatTensor(X_test).to(device)
y_train = torch.LongTensor(y_train).to(device)
y_test = torch.LongTensor(y_test).to(device)
print(X_train,X_test,y_train,y_test)

#定义神经网络

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4,10)
        self.fc2 = nn.Linear(10,10)
        self.fc3 = nn.Linear(10,3)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net().to(device)
#损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(model)

#训练模型
for epoch in range(10000):
    y_pred = model.forward(X_train).to(device)
    loss = criterion(y_pred, y_train).to(device)
    print(epoch,loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

这是上次的代码,我们加一行代码如下就保存了我们的模型了

模型的保存

#使用保存状态字典的方法是官方推荐的方法
PATH = 'model.pth'
torch.save(model.state_dict(), PATH)

PyTorch 模型将训练得到的参数存储在内部状态字典中,该字典名为 state_dict。这些参数可通过 torch.save 进行持久化保存。

模型的加载

要加载模型权重,您需要先创建同款模型的实例,然后使用 load_state_dict() 方法加载参数。

import torch
import torch.nn as nn
import torch.nn.functional as F


#创建同款模型

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4,10)
        self.fc2 = nn.Linear(10,10)
        self.fc3 = nn.Linear(10,3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model=Net()

模型加载

model.load_state_dict(torch.load("model.pth"))

model.eval()

print(model)
© 版权声明
THE END
喜欢就支持一下吧
点赞13赞赏 分享
评论 抢沙发

请登录后发表评论

    暂无评论内容