代码之家  ›  专栏  ›  技术社区  ›  casolorz

初始化NLClassifier时出错:为\u default\u input\u Type\u id服务的输入张量的类型不匹配:0。请求字符串,获取INT32

  •  0
  • casolorz  · 技术社区  · 4 年前

    我正在努力学习如何在Android上使用ML。我拿到了 Text Classification demo 正在工作,而且似乎工作得很好。然后我试着创建自己的模型。

    我用来创建自己模型的代码是:

    import numpy as np
    import os
    
    from tflite_model_maker import model_spec
    from tflite_model_maker import text_classifier
    from tflite_model_maker.config import ExportFormat
    from tflite_model_maker.text_classifier import AverageWordVecSpec
    from tflite_model_maker.text_classifier import DataLoader
    
    import tensorflow as tf
    assert tf.__version__.startswith('2')
    tf.get_logger().setLevel('ERROR')
    
    spec = model_spec.get('mobilebert_classifier')
    
    train_data = DataLoader.from_csv(
        filename='/path to file/train.csv',
        text_column='sentence',
        label_column='label',
        model_spec=spec,
        is_training=True)
    
    model = text_classifier.create(train_data, model_spec=spec, epochs=10)
    
    model.export(export_dir='average_word_vec')
    

    代码似乎运行良好,并创建了一个 model.tflite 帮我归档。然后我更换了演示 tflite 和我一起归档。但是当我运行演示时,我得到了以下错误:

     java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access$100(NLClassifier.java:67)
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier$2.createHandle(NLClassifier.java:223)
            at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
            at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
            at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
            at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$1$MainActivity(MainActivity.java:67)
            at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)
    

    我错过了什么?

    0 回复  |  直到 4 年前
        1
  •  1
  •   user3152729    4 年前

    在你的代码中,你训练了一个MobileBERT模型,但保存到了平均单词向量的路径? spec=model_spec.get('mobilebert_classifier') 模型导出(export_dir='average_word_vec')

    一种可能性是:您使用的是average_word_vec模型,但添加了MobileBERT元数据,因此预处理不匹配。

    你能按照Model Maker教程再试一次吗? https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb 确保更改导出路径。