代码之家  ›  专栏  ›  技术社区  ›  aerin

model.train()在pytorch中做什么?

  •  12
  • aerin  · 技术社区  · 7 年前

    它叫什么 forward() 在里面 nn.Module ? 我想当我们叫模特的时候, forward 正在使用方法。 为什么需要指定train()?

    2 回复  |  直到 7 年前
        1
  •  35
  •   Umang Gupta    7 年前

    model.train() 告诉你的模特你正在训练模特。因此,有效的层,如辍学,BatchNorm等,在列车和测试程序上表现不同,知道发生了什么,因此可以相应的表现。

    更多细节: 它设置训练模式 (见 source code ). 您可以调用model.eval()或model.train(mode=False)来告诉您正在测试。 期待是有点直觉的 train 函数来训练模型,但它不这样做它只是设定了模式。

        2
  •  7
  •   prosti    6 年前

    这是密码 module.train() :

    def train(self, mode=True):
            r"""Sets the module in training mode."""      
            self.training = mode
            for module in self.children():
                module.train(mode)
            return self
    

    这里是 module.eval .

    def eval(self):
            r"""Sets the module in evaluation mode."""
            return self.train(False)
    

    模式 train eval 是我们唯一可以设置模块的两种模式,它们正好相反。

    那只是一个 self.training 标志和当前 只有 dropout bachnorm 关心那面旗帜。

    默认情况下,此标志设置为 True 是的。

        3
  •  1
  •   kelam gautam    7 年前

    有两种方法可以让模型知道您的意图,即您想培训模型还是想使用模型进行评估。 在model.train()的情况下,模型知道它必须学习层,当我们使用model.eval()时,它指示模型不需要学习任何新的内容,并且模型用于测试。 model.eval()也是必需的,因为在pytorch中,如果我们使用batchnorm,并且在测试期间,如果我们只想传递一个图像,那么如果未指定model.eval(),pytorch将抛出一个错误。