代码之家  ›  专栏  ›  技术社区  ›  Fabio Picchi

joblib转储上的MemoryError

  •  2
  • Fabio Picchi  · 技术社区  · 7 年前

    我运行了以下代码片段来训练文本分类模型。我对它进行了大量的优化,运行非常平稳,但仍然使用了大量RAM。我们的数据集是巨大的(1300万个文档+词汇表中的1800万个单词),但在我看来,执行点抛出错误是非常奇怪的。脚本:

    encoder = LabelEncoder()
    y = encoder.fit_transform(categories)
    classes = list(range(0, len(encoder.classes_)))
    
    vectorizer = CountVectorizer(vocabulary=vocabulary,
                                 binary=True,
                                 dtype=numpy.int8)
    
    classifier = SGDClassifier(loss='modified_huber',
                               n_jobs=-1,
                               average=True,
                               random_state=1)
    
    tokenpath = modelpath.joinpath("tokens")
    for i in range(0, len(batches)):
        token_matrix = joblib.load(
            tokenpath.joinpath("{}.pickle".format(i)))
        batchsize = len(token_matrix)
        classifier.partial_fit(
            vectorizer.transform(token_matrix),
            y[i * batchsize:(i + 1) * batchsize],
            classes=classes
        )
    
    joblib.dump(classifier, modelpath.joinpath('classifier.pickle'))
    joblib.dump(vectorizer, modelpath.joinpath('vectorizer.pickle'))
    joblib.dump(encoder, modelpath.joinpath('category_encoder.pickle'))
    joblib.dump(options, modelpath.joinpath('extraction_options.pickle'))
    

    我在这行找到了记忆错误:

    joblib.dump(vectorizer, modelpath.joinpath('vectorizer.pickle'))
    

    在执行的这一点上,训练完成,分类器已经转储。它应该由垃圾收集器收集,以防需要更多内存。除此之外,如果joblib的内存不是 compressing the data

    我对python垃圾收集器的内部工作原理没有深入的了解。我应该强制gc吗。collect()或使用“del”语句释放不再需要的对象?

    更新时间:

    我尝试过使用HashingVectorier,尽管它大大减少了内存使用量,但矢量化速度要慢得多,因此它不是一个很好的替代方法。

    我必须对矢量器进行pickle处理,以便稍后在分类过程中使用它,以便生成提交给分类器的稀疏矩阵。我将在此处发布我的分类代码:

    extracted_features = joblib.Parallel(n_jobs=-1)(
        joblib.delayed(features.extractor) (d, extraction_options) for d in documents)
    
    probabilities = classifier.predict_proba(
        vectorizer.transform(extracted_features))
    
    predictions = category_encoder.inverse_transform(
        probabilities.argmax(axis=1))
    
    trust = probabilities.max(axis=1)
    
    1 回复  |  直到 7 年前
        1
  •  2
  •   krassowski    7 年前

    如果您正在向 CountVectorizer ,以后在分类过程中重新创建它应该不会有问题。由于您提供的是一组字符串而不是映射,因此可能需要使用已解析的词汇表,您可以通过以下方式访问该词汇表:

    parsed_vocabulary = vectorizer.vocabulary_
    joblib.dump(parsed_vocabulary, modelpath.joinpath('vocabulary.pickle'))
    

    然后加载它并用于重新创建 计数矢量器 :

    vectorizer = CountVectorizer(
        vocabulary=parsed_vocabulary,
        binary=True,
        dtype=numpy.int8
    )
    

    注意,这里不需要使用joblib;标准酸洗应执行相同的操作;使用任何可用的替代方法都可能获得更好的结果,值得一提的是PyTables。

    如果这也占用了大量内存,您应该尝试使用原始 vocabulary 用于矢量器的再生;目前,当提供一组字符串作为词汇表时,向量器只是将集合转换为排序列表,因此您不必担心再现性(尽管在生产中使用之前我会仔细检查)。或者您可以自己将集合转换为列表。

    总而言之:因为你没有 fit() 矢量器,使用 计数矢量器 是它的 transform() 方法由于所需的全部数据都是词汇表(和参数),因此您可以减少内存消耗,无论是否处理词汇表。

    由于您要求从官方渠道获得答案,我想向您指出: https://github.com/scikit-learn/scikit-learn/issues/3844 其中,scikit的所有者和贡献者提到重新创建 计数矢量器 ,尽管是出于其他目的。在链接的repo中报告问题可能会更好,但请确保包含会导致过度内存使用问题的数据集,以使其可复制。

    最后你可以使用 HashingVectorizer 正如前面在评论中提到的。

    PS:关于使用 gc.collect() -在这种情况下,我会试一试;关于技术细节,您将发现许多关于如何解决此问题的问题。