代码之家  ›  专栏  ›  技术社区  ›  Aaditya Ura

使用AutoModelForSequenceClassification的Hugginface多类分类

  •  0
  • Aaditya Ura  · 技术社区  · 3 年前

    我试着用Hugginface的 AutoModelForSequenceClassification API用于多类分类,但对其配置感到困惑。

    我的数据集在一个热编码中,问题类型是多类(一次一个标签)

    我尝试过的:

    from transformers import AutoModelForSequenceClassification
    
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased",
                                                               num_labels=6,
                                                               id2label=id2label,
                                                               label2id=label2id)
    
    
    
    batch_size = 8
    metric_name = "f1"
    
    
    
    from transformers import TrainingArguments, Trainer
    
    args = TrainingArguments(
        f"bert-finetuned-english",
        evaluation_strategy = "epoch",
        save_strategy = "epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=10,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
        #push_to_hub=True,
    )
    
    
    trainer = Trainer(
        model,
        args,
        train_dataset=encoded_dataset["train"],
        eval_dataset=encoded_dataset["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )
    

    这是正确的吗?

    我对损失函数感到困惑,当我打印一个正向传递时,损失是 BinaryCrossEntropyWithLogitsBackward

    SequenceClassifierOutput([('loss',
                               tensor(0.6986, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)),
                              ('logits',
                               tensor([[-0.5496,  0.0793, -0.5429, -0.1162, -0.0551]],
                                      grad_fn=<AddmmBackward0>))])
    

    其用于多标签或二进制分类任务。它应该使用“nn.CrossEntropyLoss”?如何正确使用这个API多类和定义损失函数?

    0 回复  |  直到 3 年前
        1
  •  0
  •   Anton Gorinenko    3 年前

    您有六个类,每个单元格中的值为1或0以进行编码。例如,张量[0.,0.,0.,0.,1.,0.]表示第五类。我们的任务是预测六个标签([1.,0.,0.,0..,0.]),并将它们与基本事实([0.,0.,0..、0..、1..、0.])进行比较。对于训练,我们使用损失函数BinaryCrossEntropyWithLogitsBackward