
How to find and fix an IndexError raised while loading a .txt file? Python

  •  0
  • twhale  ·  Tech Community  ·  7 years ago

    I am trying to train a sequence-to-sequence model for machine translation. I use a public .txt dataset containing two columns, English to German (one pair per line, with a tab separating the languages): http://www.manythings.org/anki/deu-eng.zip This works fine. However, I run into a problem when trying to use my own dataset.

    My own DataFrame looks like this:

        Column 1    Column 2
    0   English a   German a
    1   English b   German b
    2   English c   German c
    3   English d   German d
    4   ...         ...
    

    To use it in the same script, I save this DataFrame to a .txt file as follows (the goal is again one pair per line, with a tab separating the languages):

    df.to_csv("dataset.txt", index=False, sep='\t')
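For illustration, here is what that `to_csv` call actually writes, using an in-memory buffer instead of a file and a small two-row frame standing in for the DataFrame above. Note that `header=True` is the pandas default, so the column names are written as the first line of the file as well:

```python
import io
import pandas as pd

# Small stand-in for the DataFrame from the question.
df = pd.DataFrame({'Column 1': ['English a', 'English b'],
                   'Column 2': ['German a', 'German b']})

buf = io.StringIO()
df.to_csv(buf, index=False, sep='\t')
print(buf.getvalue())
# Column 1	Column 2    <- the header row is written too (header=True by default)
# English a	German a
# English b	German b
```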

    The problem occurs in the code that cleans the data:

    # imports used by the functions below
    import re
    import string
    from pickle import dump
    from unicodedata import normalize
    from numpy import array

    # load doc into memory
    def load_doc(filename):
        # open the file as read only
        file = open(filename, mode='rt', encoding='utf-8')
        # read all text
        text = file.read()
        # close the file
        file.close()
        return text
    
    # split a loaded document into sentences
    def to_pairs(doc):
        lines = doc.strip().split('\n')
        pairs = [line.split('\t') for line in lines]
        return pairs
    
    # clean a list of lines
    def clean_pairs(lines):
        cleaned = list()
        # prepare regex for char filtering
        re_print = re.compile('[^%s]' % re.escape(string.printable))
        # prepare translation table for removing punctuation
        table = str.maketrans('', '', string.punctuation)
        for pair in lines:
            clean_pair = list()
            for line in pair:
                # normalize unicode characters
                line = normalize('NFD', line).encode('ascii', 'ignore')
                line = line.decode('UTF-8')     
                # tokenize on white space
                line = line.split()
                # convert to lowercase
                line = [word.lower() for word in line]       
                # remove punctuation from each token
                line = [word.translate(table) for word in line]       
                # remove non-printable chars form each token
                line = [re_print.sub('', w) for w in line]                 
                # remove tokens with numbers in them
                line = [word for word in line if word.isalpha()]           
                # store as string
                clean_pair.append(' '.join(line))
    #            print(clean_pair)
            cleaned.append(clean_pair)
    #        print(cleaned)
        print(array(cleaned))
        return array(cleaned) # something goes wrong here
    
    # save a list of clean sentences to file
    def save_clean_data(sentences, filename):
        dump(sentences, open(filename, 'wb'))
        print('Saved: %s' % filename)
    
    # load dataset
    filename = 'data/dataset.txt'
    doc = load_doc(filename)
    # split into english-german pairs
    pairs = to_pairs(doc)
    # clean sentences
    clean_pairs = clean_pairs(pairs)
    # save clean pairs to file
    save_clean_data(clean_pairs, 'english-german.pkl')
    # spot check
    for i in range(100):
        print('[%s] => [%s]' % (clean_pairs[i,0], clean_pairs[i,1]))
    

    The last line raises the following error:

    IndexError                          Traceback (most recent call last)
    <ipython-input-2-052d883ebd4c> in <module>()
         72 # spot check
         73 for i in range(100):
    ---> 74     print('[%s] => [%s]' % (clean_pairs[i,0], clean_pairs[i,1]))
         75 
         76 # load a clean dataset
    
    IndexError: too many indices for array
    

    One strange thing is that the output of the following line differs between the standard dataset and my own dataset:

    # Standard dataset:
    return array(cleaned)
    [['hi' 'hallo']
     ['hi' 'gru gott']
     ['run' 'lauf']]
    
    # My own dataset:
    return array(cleaned)
    [list(['hi' 'hallo'])
     list(['hi' 'gru gott'])
     list(['run' 'lauf'])]
    

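The difference between those two outputs can be reproduced in isolation: rows of equal length give NumPy a true 2-D array, while ragged rows force it into a 1-D object array of lists, on which 2-D indexing raises exactly this IndexError. A minimal sketch (the `rect`/`ragged` names are illustrative):

```python
import numpy as np

# Rectangular rows -> a true 2-D array; [i, 0] style indexing works.
rect = np.array([['hi', 'hallo'], ['run', 'lauf']])
print(rect.shape)      # (2, 2)
print(rect[0, 0])      # hi

# Ragged rows (one row has only one field) -> a 1-D object array of lists.
ragged = np.array([['hi', 'hallo'], ['run']], dtype=object)
print(ragged.shape)    # (2,)
print(ragged[0][0])    # hi -- list-style indexing still works
try:
    ragged[0, 0]       # 2-D indexing on a 1-D array
except IndexError:
    print('IndexError: too many indices for array')
```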
    Can anyone explain what the problem is and how to solve it?

    1 Answer  |  7 years ago
        1
  •  0
  •   Iguananaut    7 years ago

    Your clean_pairs is not the 2-D NumPy array you expect but effectively a list structure: your own debug output shows a 1-D object array of lists. NumPy-style indexing such as clean_pairs[i,0] only works on a true 2-D array; with an array of lists you would need clean_pairs[i][0].

    The root cause is most likely in to_pairs: if any line of your file does not split into exactly two tab-separated fields, array(cleaned) cannot form a 2-D array and falls back to a 1-D object array of lists, which is exactly the difference you see between the two datasets. Check how dataset.txt was written.
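A quick way to confirm this is to list every line that does not split into exactly two tab-separated fields; any hit is what forces `array(cleaned)` into an object array. A sketch using an inline string in place of the real file (the `doc` content is made up):

```python
# Hypothetical diagnostic: report (line number, text) for every line that
# does not split into exactly two tab-separated fields.
doc = "Column 1\tColumn 2\nhi\thallo\nbroken line without a tab\nrun\tlauf"
bad = [(i, line) for i, line in enumerate(doc.strip().split('\n'))
       if len(line.split('\t')) != 2]
print(bad)  # [(2, 'broken line without a tab')]
```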
