pytorch-lightning

1. PyTorch Lightning Overview

PyTorch Lightning wraps the main PyTorch training workflow.

Plain PyTorch and PL code are essentially the same. The difference is that in plain PyTorch you build everything yourself (model, dataloader, loss, train/test loops, checkpointing, model saving, and so on), while PL organizes these pieces into structured modules (similar to Keras).

The following example shows the difference between the two clearly:

A Lightning model is hardware-agnostic, so all .cuda() and .to(device) calls can be removed.
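To make the contrast concrete, here is a minimal hand-written PyTorch loop on synthetic data (shapes and hyperparameters are arbitrary illustrations, not from the original). Every line of boilerplate below, the manual device moves, zero_grad/backward/step, and the epoch loop, is what the Lightning Trainer takes over once the logic is moved into a LightningModule:

```python
import torch
from torch import nn

# Manual device handling -- exactly what disappears in Lightning
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 1).to(device)               # manual .to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

x = torch.randn(32, 4).to(device)                # manual data movement
y = torch.randn(32, 1).to(device)

for epoch in range(5):                           # manual epoch loop
    optimizer.zero_grad()                        # manual gradient reset
    loss = loss_fn(model(x), y)
    loss.backward()                              # manual backward pass
    optimizer.step()                             # manual optimizer step
```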

2. Building a Model

2.1 Define the LightningModule

import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28)
        )

    def forward(self, x):
        # in Lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
# def __init__(self): defines the network architecture (model)
# def forward(self, x): defines the inference/prediction forward pass
# def training_step(self, batch, batch_idx): defines the train loop
# def configure_optimizers(self): defines the optimizer
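Because forward and training_step are independent, calling the model instance directly runs forward only (inference), while the Trainer invokes training_step during fit. This call-to-forward dispatch comes from nn.Module itself; a small self-contained sketch (the Encoder class below is an illustrative stand-in, not from the original):

```python
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        # same encoder shape as the autoencoder above
        self.net = nn.Sequential(nn.Linear(28 * 28, 64),
                                 nn.ReLU(),
                                 nn.Linear(64, 3))

    def forward(self, x):
        return self.net(x)

enc = Encoder()
x = torch.randn(2, 28 * 28)
embedding = enc(x)        # __call__ dispatches to forward
print(embedding.shape)    # torch.Size([2, 3])
```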

2.2 Put the DataLoaders into a DataModule

import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):

    def setup(self, stage):
        # transforms for images
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])

        # prepare transforms standard to MNIST
        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

2.3 Train the Model

# init model
autoencoder = LitAutoEncoder()

# init the DataModule defined above
mnist_dm = MNISTDataModule()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(autoencoder, datamodule=mnist_dm)
# two arguments: the model and the datamodule

Reference: https://blog.csdn.net/u014264373/article/details/117021901