代码之家  ›  专栏  ›  技术社区  ›  Ketaki Kolhatkar

Blenderbot微调

  •  0
  • Ketaki Kolhatkar  · 技术社区  · 3 年前

    我一直在尝试微调拥抱脸(HuggingFace)的对话模式:Blendebot。我尝试了官方拥抱脸网站上给出的传统方法,该网站要求我们使用培训师进行。train()方法。我试着用。compile()方法。我在我的数据集上尝试了使用PyTorch和TensorFlow进行微调。这两种方法似乎都失败了,并给了我们一个错误,即Blenderbot模型没有称为compile或train的方法。我甚至在网上到处查看Blenderbot是如何根据我的自定义数据进行微调的,但它没有正确提到运行时不会出错。我看过Youtube教程、博客和StackOverflow帖子,但没有人回答这个问题。希望有人能在这里回应并帮助我。我也愿意使用其他拥抱式对话模型进行微调。

    这是我用来微调blenderbot模型的链接。

    微调方法: https://huggingface.co/docs/transformers/training

    Blenderbot: https://huggingface.co/docs/transformers/model_doc/blenderbot

    from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
    mname = "facebook/blenderbot-400M-distill"
    model = BlenderbotForConditionalGeneration.from_pretrained(mname)
    tokenizer = BlenderbotTokenizer.from_pretrained(mname)
    
    
    #FOR TRAINING: 
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=small_train_dataset,
        eval_dataset=small_eval_dataset,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    
    #OR
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=tf.metrics.SparseCategoricalAccuracy(),
    )
    
    model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)
    

    这些都不管用。

    0 回复  |  直到 3 年前
        1
  •  1
  •   AloneTogether    3 年前

    也许试着使用 TFBlenderbotForConditionalGeneration 的类 Tensorflow . 它有你需要的:

    import tensorflow as tf
    from transformers import BlenderbotTokenizer, TFBlenderbotForConditionalGeneration
    
    mname = "facebook/blenderbot-400M-distill"
    model = TFBlenderbotForConditionalGeneration.from_pretrained(mname)
    tokenizer = BlenderbotTokenizer.from_pretrained(mname)
    
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=tf.metrics.SparseCategoricalAccuracy(),
    )
    ....
    

    请参阅 docs 了解更多信息。

    推荐文章