我在用 torch 评估我的模型。下面是从文本文件加载批数据的代码:
def batchify(data, bsz):
    """Trim a 1-D index tensor to a whole number of batches and reshape.

    Args:
        data: 1-D tensor of token indices.
        bsz: batch size (number of rows in the result).

    Returns:
        A (bsz, nbatch) tensor. Trailing elements that do not fill a
        complete column are dropped; if len(data) < bsz the result is
        an empty (bsz, 0) tensor, so callers should check for that.
    """
    # BUG FIX: the original used `nbatch = data.size(0)`, which makes
    # narrow() request nbatch * bsz elements — more than the tensor
    # holds whenever bsz > 1. Divide first (standard PyTorch
    # word_language_model formulation).
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # drop the remainder
    data = data.view(bsz, -1)               # one row per batch element
    return data
def load_file(path, vocab, direction):
    """Read a text file and map its characters to a 1-D index tensor.

    Args:
        path: path of the text file to read.
        vocab: mapping with a 'char' entry whose .map() converts a
            list of characters to integer indices.
        direction: 'backward' reverses the index sequence; any other
            value keeps file order.

    Returns:
        torch.Tensor of character indices.
    """
    # Use a context manager so the file handle is always closed
    # (the original leaked it). read() replaces ''.join(readlines())
    # — both yield the same string.
    with open(path) as f:
        chars = list(f.read())
    idx = vocab['char'].map(chars)
    if direction == 'backward':
        idx = idx[::-1]
    return torch.tensor(idx)
def load_data(path, vocab, direction):
    """Generator yielding the single index tensor loaded from `path`."""
    yield load_file(path, vocab, direction)
这很好用:
# Evaluation driver: load the corpus, unwrap the generator, batch it.
eval_file_or_dir = os.path.join(BASE_PATH, 'shakespeare.txt')
# BUG FIX: the original passed the undefined name `eval_file` here,
# which raises NameError; use the variable defined just above.
data = load_data(eval_file_or_dir, vocab, direction)
if isinstance(data, GeneratorType):
    data = list(data)
    data = data[0]
batches = batchify(data, batch_size)
我得到了
torch.Size([100, 6])
现在,我想从字符串加载数据,所以我写了
def load_text(text, vocab, direction):
    """Map the characters of an in-memory string to a 1-D index tensor.

    Mirrors the load-from-file path but takes the text directly.
    Yields a single torch.Tensor of character indices; 'backward'
    reverses the order.
    """
    # StringIO + readlines + join reassembles the original string
    # unchanged, so the string can be split into characters directly.
    characters = list(text)
    indices = vocab['char'].map(characters)
    if direction == 'backward':
        indices = indices[::-1]
    yield torch.tensor(indices)
但它并没有像预期的那样起作用:
# Same driver as above, but for an in-memory string instead of a file.
data = load_text(text, vocab, direction)
if isinstance(data, GeneratorType):
data = list(data)
data = data[0]
# NOTE(review): batchify drops elements that do not fill a complete
# column, so if len(text) < batch_size the result is an empty
# (batch_size, 0) tensor — presumably the cause of the empty output
# reported below; verify against the actual length of `text`.
batches = batchify(data, batch_size)
print(batches, batches.size())
我得到的却是一个空的 tensor:
tensor([], size=(100, 0), dtype=torch.int64) torch.Size([100, 0])