来源:FC游戏站 更新:2020-11-12
用手机看
Grid AI 是用于在云上进行大规模训练模型的本机平台。该平台允许构建深度学习模型的研究者在大规模计算上进行迭代,然后将模型部署到可扩展环境中,该环境能够处理深度学习系统的最大流量。
2. 度量指标
pytorch_lightning.metrics 是一种 Metrics API,旨在在 PyTorch 和 PyTorch Lightning 中轻松地进行度量指标的开发和使用。更新后的 API 提供了一种内置方法,可针对每个步骤跨多个 GPU(进程)计算指标,同时存储统计信息。这可以让用户在一个阶段结束时计算指标,而无需担心任何与分布式后端相关的复杂度。
class LitModel(pl.LightningModule): def __init__(self): ... self.train_acc = pl.metrics.Accuracy() self.valid_acc = pl.metrics.Accuracy() def training_step(self, batch, batch_idx): logits = self(x) ... self.train_acc(logits, y) # log step metric self.log('train_acc_step', self.train_acc) def validation_step(self, batch, batch_idx): logits = self(x) ... self.valid_acc(logits, y) # logs epoch metrics self.log('valid_acc', self.valid_acc)
要实现自定义指标,只需将 Metric 基类子类化,FC游戏,并实现__init__()、update() 和 compute() 方法。用户需要做的就是正确调用 add_state(),以用 DDP 实现自定义指标。对使用 add_state() 添加的度量指标状态变量要调用 reset()。
from pytorch_lightning.metrics import Metric class MyAccuracy(Metric): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_stepdist_sync_on_step=dist_sync_on_step) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): preds, target = self._input_format(preds, target) assert preds.shape == target.shape self.correct += torch.sum(preds == target) self.total += target.numel() def compute(self): return self.correct.float() / self.total
3. 手动优化 VS 自动优化
使用 Lightning,用户不需要担心何时启用 / 停用 grad,只要从 training_step 中返回带有附加图的损失即可进行反向传播或更新优化器,Lightning 将会自动进行优化。
def training_step(self, batch, batch_idx): loss = self.encoder(batch[0]) return loss
但是,对于某些研究,如 GAN、强化学习或者是带有多个优化器或内部循环的某些研究,用户可以关闭自动优化,并完全由自己控制训练循环。
首先,关闭自动优化:
trainer *=* Trainer(automatic_optimization*=False*)
现在训练循环已经由用户自己掌握。
def training_step(self, batch, batch_idx, opt_idx): (opt_a, opt_b, opt_c) = self.optimizers() loss_a = self.generator(batch[0]) # use this instead of loss.backward so we can automate half # precision, etc... self.manual_backward(loss_a, opt_a, retain_graph=True) self.manual_backward(loss_a, opt_a) opt_a.step() opt_a.zero_grad() loss_b = self.discriminator(batch[0]) self.manual_backward(loss_b, opt_b) ...
4. Logging