I'm trying to learn how to use ML on Android. I got the
Text Classification demo
working, and it seems to work well. Then I tried to create my own model.
The code I used to create my model is:
import numpy as np
import os
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
spec = model_spec.get('mobilebert_classifier')
train_data = DataLoader.from_csv(
    filename='/path to file/train.csv',
    text_column='sentence',
    label_column='label',
    model_spec=spec,
    is_training=True)
model = text_classifier.create(train_data, model_spec=spec, epochs=10)
model.export(export_dir='average_word_vec')
The code seems to run fine and creates a
model.tflite
file for me. I then replaced the demo's
tflite
file with mine. But when I run the demo, I get the following error:
java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access$100(NLClassifier.java:67)
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier$2.createHandle(NLClassifier.java:223)
    at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
    at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
    at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
    at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$1$MainActivity(MainActivity.java:67)
    at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)
What am I missing?