代码之家  ›  专栏  ›  技术社区  ›  Siddhivinayak

无法从 android_asset 访问 TensorFlow 模型

  •  0
  • Siddhivinayak  · 技术社区  · 6 年前

    ClassifierActivity

    我已经在 assets 文件夹中添加了模型文件：[图片]

    [图片]

    我的 TensorFlowImageClassifier 类：

      /**
       * {@link Classifier} that runs a TensorFlow graph through {@code TensorFlowInferenceInterface}
       * and maps the graph's output scores to text labels read from an asset file.
       */
      public class TensorFlowImageClassifier implements Classifier {
        private static final String TAG = "TensorFlowImageClassifier";

        /**
         * Prefix marking a path as an Android asset. It is stripped from the label file name before
         * the name is passed to {@code AssetManager.open}.
         */
        private static final String ASSET_PREFIX = "file:///android_asset/";

        // Only return this many results with at least this confidence.
        private static final int MAX_RESULTS = 2;
        private static final float THRESHOLD = 0.1f;

        // Config values, set once by create().
        private String inputName;
        private String outputName;
        private int inputSize;
        private int imageMean;
        private float imageStd;

        // Labels in file line order; label i names output index i.
        private final List<String> labels = new ArrayList<String>();

        // Pre-allocated buffers, sized in create() and reused by every recognizeImage() call.
        private int[] intValues;
        private float[] floatValues;
        private float[] outputs;
        private String[] outputNames;

        // Passed to inferenceInterface.run(); toggled by enableStatLogging().
        private boolean logStats = false;

        private TensorFlowInferenceInterface inferenceInterface;

        private TensorFlowImageClassifier() {}

        /**
         * Reads the labels, loads the graph and pre-allocates the buffers used by
         * {@link #recognizeImage(Bitmap)}.
         *
         * @param assetManager used to open the label file; also handed to
         *     {@code TensorFlowInferenceInterface} together with {@code modelFilename}
         * @param modelFilename graph file name, passed unchanged to {@code TensorFlowInferenceInterface}
         * @param labelFilename label file in the assets folder, one label per line; the leading
         *     {@code file:///android_asset/} prefix is optional
         * @param inputSize width and height, in pixels, of the image fed to the graph
         * @param imageMean value subtracted from every 0-255 channel value
         * @param imageStd value every mean-subtracted channel value is divided by
         * @param inputName name of the graph node the image is fed into
         * @param outputName name of the graph operation whose output 0 has shape [N, NUM_CLASSES]
         * @return the initialized classifier
         * @throws RuntimeException if the label file cannot be read
         */
        public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String labelFilename,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName) {
          TensorFlowImageClassifier c = new TensorFlowImageClassifier();
          c.inputName = inputName;
          c.outputName = outputName;

          // Read the label names into memory. AssetManager wants the name relative to the assets
          // folder, so drop the prefix if the caller supplied one; a bare name is used as-is
          // instead of crashing on split(...)[1].
          String actualFilename = stripAssetPrefix(labelFilename);
          Log.i(TAG, "Reading labels from: " + actualFilename);
          // try-with-resources closes the reader even if reading fails part-way through.
          try (BufferedReader br =
              new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)))) {
            String line;
            while ((line = br.readLine()) != null) {
              c.labels.add(line);
            }
          } catch (IOException e) {
            throw new RuntimeException("Problem reading label file: " + actualFilename, e);
          }

          c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

          // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
          final Operation operation = c.inferenceInterface.graphOperation(outputName);
          final int numClasses = (int) operation.output(0).shape().size(1);
          Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);

          // Ideally, inputSize could have been retrieved from the shape of the input operation.
          // Alas, the placeholder node for input in the graphdef typically used does not specify a
          // shape, so it must be passed in as a parameter.
          c.inputSize = inputSize;
          c.imageMean = imageMean;
          c.imageStd = imageStd;

          // Pre-allocate buffers: one int per pixel, three floats (R, G, B) per pixel, one score
          // per class.
          c.outputNames = new String[] {outputName};
          c.intValues = new int[inputSize * inputSize];
          c.floatValues = new float[inputSize * inputSize * 3];
          c.outputs = new float[numClasses];

          return c;
        }

        /**
         * Returns {@code path} without a leading {@code file:///android_asset/} prefix, or
         * {@code path} unchanged when it has no such prefix.
         */
        private static String stripAssetPrefix(String path) {
          return path.startsWith(ASSET_PREFIX) ? path.substring(ASSET_PREFIX.length()) : path;
        }

        /**
         * Runs the graph on {@code bitmap} and returns at most {@code MAX_RESULTS} recognitions whose
         * score exceeds {@code THRESHOLD}, highest score first.
         *
         * <p>The pixel buffer holds exactly {@code inputSize * inputSize} pixels, so the bitmap is
         * expected to be {@code inputSize} x {@code inputSize}; a larger bitmap does not fit into the
         * buffer, and a smaller one leaves stale pixels from a previous call in it.
         */
        @Override public List<Recognition> recognizeImage(final Bitmap bitmap) {
          // Log this method so that it can be analyzed with systrace.
          Trace.beginSection("recognizeImage");

          Trace.beginSection("preprocessBitmap");
          // Preprocess the image data from 0-255 int to normalized float based
          // on the provided parameters.
          bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
          for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            // Unpack the packed int pixel into R, G, B channels (bits 16-23, 8-15, 0-7).
            floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd;
            floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd;
          }
          Trace.endSection();

          // Copy the input data into TensorFlow as a [1, inputSize, inputSize, 3] tensor.
          Trace.beginSection("feed");
          inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
          Trace.endSection();

          // Run the inference call.
          Trace.beginSection("run");
          inferenceInterface.run(outputNames, logStats);
          Trace.endSection();

          // Copy the output Tensor back into the output array.
          Trace.beginSection("fetch");
          inferenceInterface.fetch(outputName, outputs);
          Trace.endSection();

          // Find the best classifications.
          PriorityQueue<Recognition> pq =
              new PriorityQueue<Recognition>(
                  3,
                  new Comparator<Recognition>() {
                    @Override
                    public int compare(Recognition lhs, Recognition rhs) {
                      // Intentionally reversed to put high confidence at the head of the queue.
                      return Float.compare(rhs.getConfidence(), lhs.getConfidence());
                    }
                  });
          for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > THRESHOLD) {
              // Fall back to "unknown" when the label file has fewer lines than output classes.
              String title = labels.size() > i ? labels.get(i) : "unknown";
              pq.add(new Recognition(String.valueOf(i), title, outputs[i], null));
            }
          }
          final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
          int recognitionsSize = Math.min(pq.size(), MAX_RESULTS);
          for (int i = 0; i < recognitionsSize; ++i) {
            recognitions.add(pq.poll());
          }
          Trace.endSection(); // "recognizeImage"
          return recognitions;
        }

        /** Enables or disables stat logging for subsequent inference runs. */
        @Override public void enableStatLogging(boolean debug) {
          // Store the argument (the field was previously assigned to itself, so this was a no-op).
          this.logStats = debug;
        }

        /** Returns the stat string reported by the underlying inference interface. */
        @Override public String getStatString() {
          return inferenceInterface.getStatString();
        }

        /** Releases the underlying inference interface. */
        @Override public void close() {
          inferenceInterface.close();
        }
      }
    
    1 回复  |  直到 6 年前
        1
  •  0
  •   Siddhivinayak    6 年前

    好吧，答案其实很简单：我的错误在于给 MODEL_FILE 和 LABEL_FILE 这两个变量指定了错误的路径（当时参考了网上其他仓库的写法）。正确的写法如下，两个路径都使用 file:///android_asset/ 前缀，指向 assets 文件夹。很抱歉问了这么简单的问题。

    // Both paths point into the app's assets folder via the file:///android_asset/ prefix.
    // The file names (graph.pb, labels.txt) must match the files actually placed in assets.
    private static final String MODEL_FILE = "file:///android_asset/graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/labels.txt";