pytorch-lightning

1. Overview of pytorch-lightning

PyTorch Lightning (PL) encapsulates the main steps of a PyTorch training run into a fixed structure.
At heart, PyTorch and PL code are the same. The difference is that plain PyTorch makes you build every piece yourself (model, dataloader, loss, train/test loops, checkpointing, model saving, and so on), whereas PL organizes these pieces into a standard structure, much like Keras.
The sketch below makes the difference between the two clear. Note also that a Lightning model is hardware-agnostic, so all .cuda() and .to(device) calls can be removed.
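For contrast, here is a minimal sketch of the hand-written plain-PyTorch side (my own illustration with toy random data, not code from the original post). In the Lightning version in section 2.1, the loop body moves into training_step, the optimizer into configure_optimizers, and the .to(device) calls disappear because the Trainer owns device placement:

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Toy data so the sketch is self-contained (random tensors, not MNIST).
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 28 * 28), torch.randint(0, 10, (256,))),
    batch_size=64,
)

# Plain PyTorch: you own the loop, the optimizer steps, and the device moves.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(28 * 28, 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)      # manual device handling
        loss = F.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()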
2. Building the model

2.1 Define the LightningModule

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Encoder: 784 -> 64 -> 3
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        # Decoder: 3 -> 64 -> 784
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28)
        )

    def forward(self, x):
        # Inference behavior: return the embedding.
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # One training step; the Trainer supplies the batch and runs backprop.
        x, y = batch
        x = x.view(x.size(0), -1)          # flatten images to vectors
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)        # reconstruction loss
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
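The module above only defines training behavior. Since the DataModule in 2.2 also exposes a val_dataloader, a matching validation_step can be added to the same class (a sketch following the same pattern; validation_step is a standard LightningModule hook that the Trainer calls for each validation batch):

# Added inside LitAutoEncoder:
def validation_step(self, batch, batch_idx):
    # Same reconstruction loss as training, computed on held-out batches;
    # the Trainer runs this automatically after each training epoch.
    x, y = batch
    x = x.view(x.size(0), -1)
    x_hat = self.decoder(self.encoder(x))
    val_loss = F.mse_loss(x_hat, x)
    self.log('val_loss', val_loss)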
2.2 Put the dataloaders into a DataModule

import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage):
        # Builds the train/test datasets with standard MNIST normalization.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
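One caveat: setup() runs on every process in distributed training, so downloading inside it can duplicate work. Lightning provides a prepare_data() hook that runs once on a single process; a sketch of moving the download step there (download arguments copied from setup above):

# Added inside MNISTDataModule:
def prepare_data(self):
    # Called once, on a single process; the safe place for downloads.
    MNIST(os.getcwd(), train=True, download=True)
    MNIST(os.getcwd(), train=False, download=True)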
2.3 Train the model

autoencoder = LitAutoEncoder()
datamodule = MNISTDataModule()
trainer = pl.Trainer()
# fit() drives the whole loop: setup, dataloaders, epochs, backprop, logging.
trainer.fit(autoencoder, datamodule=datamodule)
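Trainer behavior is configured through constructor arguments. A sketch with two commonly used options (values are illustrative; gpus is the PyTorch Lightning 1.x flag, replaced by accelerator/devices in 2.x):

trainer = pl.Trainer(
    max_epochs=10,   # stop after 10 passes over the training set
    gpus=1,          # train on one GPU (1.x API; 2.x uses accelerator="gpu", devices=1)
)
trainer.fit(autoencoder, datamodule=datamodule)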
Reference: https://blog.csdn.net/u014264373/article/details/117021901