代码之家  ›  专栏  ›  技术社区  ›  loretoparisi

火炬模型由弦计算

  •  0
  • loretoparisi  · 技术社区  · 5 年前

    我对我的工作进行评估 torch 从文本文件加载批处理的模型:

    def batchify(data, bsz):
        nbatch = data.size(0)
        data = data.narrow(0, 0, nbatch * bsz)
        data = data.view(bsz, -1)
        return data
    
    def load_file(path, vocab, direction):
        lines = open(path).readlines()
        data = list(''.join(lines))
        idx = vocab['char'].map(data)
        if direction == 'backward': idx = idx[::-1]
        return torch.tensor(idx)
    
    def load_data(path, vocab, direction):
        data = load_file(path, vocab, direction)
        yield data
    

    这很好用:

    eval_file_or_dir = os.path.join(BASE_PATH,'shakespeare.txt')
    data = load_data(eval_file, vocab, direction)
    if isinstance(data, GeneratorType):
        data = list(data)
        data = data[0]
    batches = batchify(data, batch_size)
    

    我得到了 torch.Size([100, 6])

    现在,我想从字符串加载数据,所以我写了

    def load_text(text, vocab, direction):
        buf = io.StringIO(text)
        lines = buf.readlines()
        data = list(''.join(lines))
        idx = vocab['char'].map(data)
        if direction == 'backward': idx = idx[::-1]
        yield torch.tensor(idx)
    

    但它并没有像预期的那样起作用:

    data = load_text(text, vocab, direction)
    if isinstance(data, GeneratorType):
        data = list(data)
        data = data[0]
    batches = batchify(data, batch_size)
    print(batches, batches.size())
    

    我得到一个空的 tensor :

    tensor([], size=(100, 0), dtype=torch.int64) torch.Size([100, 0])
    
    0 回复  |  直到 5 年前