
Repeating tf.data.Dataset.from_generator() with a Python generator that iterates over a database

  • ARAT  ·  6 years ago

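    I have a panel dataset on which I want to train a long short-term memory (LSTM) network. The dataset comes from a PostgreSQL database. My data is structured as follows: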

    enter image description here

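    So my number of time steps is 4. It is a many-to-many LSTM: both my input and my output are sequences. The input will have shape [Batch_size, 4, 23] and the output will have shape [Batch_size, 4, 2] (I one-hot encode the labels).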
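To make the target shape concrete, here is a minimal NumPy sketch of how integer labels of shape [Batch_size, 4] become one-hot targets of shape [Batch_size, 4, 2]. The helper name `one_hot` and the sample labels are hypothetical; the question itself performs this step with `tf.one_hot`.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Map integer labels of any shape to one-hot vectors along a new last axis."""
    labels = np.asarray(labels, dtype=int)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.int32)[labels]


# Hypothetical batch: 3 people, 4 time steps each, binary labels.
labels = np.array([[0, 1, 1, 0],
                   [1, 1, 0, 0],
                   [0, 0, 0, 1]])
encoded = one_hot(labels, num_classes=2)
print(labels.shape, encoded.shape)  # (3, 4) (3, 4, 2)
```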

    I am using a Python generator to fetch the rows. I call fetchmany with a record count of 4, because four rows correspond to one particular person.

    import numpy as np
    import psycopg2
    import passwords_and_paths  # module that holds the database credentials
    
    class it_try:
        def __init__(self, sql, number_of_records):
            self.sql = sql
            self.number_of_records = number_of_records
            db = passwords_and_paths.database
            self.pgConnectString = "host='/var/run/postgresql' port='{}' dbname='{}' user='{}' password='{}'".format(
                db['port'], db['name'], db['user'], db['pass'])
            self.pgConnection = psycopg2.connect(self.pgConnectString)
            # A named cursor is server-side, so rows are streamed instead of loaded at once.
            self.pgCursor = self.pgConnection.cursor(name='fetch_large_result')
            self.pgCursor.execute(self.sql)
    
        def __iter__(self):
            return self
    
        def __next__(self):
            row = self.pgCursor.fetchmany(self.number_of_records)
            current_obs = []
            for i in row:
                current_obs.append(i)
    
            features = np.array(current_obs)[:,3:26]
            labels = np.array(current_obs)[:,-1].astype(int)
    
            return features, labels
    
        def __del__(self):
            self.pgCursor.close()
    

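    The features have shape [4,23] and the labels have shape [4,]. I then pass the generator to the TensorFlow function tf.data.Dataset.from_generator(). The shapes and data types are declared correctly. Here I one-hot encode the labels and batch 3 people per call.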

    generator = it_try(sql = 'SELECT * FROM public.basetable order by year, customer_id, quarter', number_of_records = 4)
    train_dataset = tf.data.Dataset.from_generator(lambda: generator, (tf.float32, tf.int32), (tf.TensorShape([4,23]), tf.TensorShape([4,])))
    train_dataset=train_dataset.map(lambda *x:(x[0], tf.cast(tf.one_hot(x[1],2),tf.int32)))
    train_dataset = train_dataset.batch(3)
    

    The output is <BatchDataset shapes: ((?, 4, 23), (?, 4, 2)), types: (tf.float32, tf.int32)> . So far, so good.

    I create an iterator and initialize it, and I can print batches successfully (two batches in this example).

    iterator = tf.data.Iterator.from_structure(train_dataset.output_types, 
    train_dataset.output_shapes)
    X, y = iterator.get_next()
    training_init_op = iterator.make_initializer(train_dataset)
    
    with tf.Session() as sess:
        sess.run(training_init_op)
        for batch in range(2):
            print(sess.run([X,y]))
    

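    However, when I want to pass over the training data several times (here the number of epochs is 2), I get an error. Of course, this is because I cannot reset the Python and TensorFlow iterators.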

    with tf.Session() as sess:
        for epoch in range(2):
            sess.run(training_init_op)
            for batch in range(2):
                print(sess.run([X,y]))
    

    It prints the first epoch fine, but on the second epoch I get an error.

    ---------------------------------------------------------------------------
    UnknownError                              Traceback (most recent call last)
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
       1321     try:
    -> 1322       return fn(*args)
       1323     except errors.OpError as e:
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
       1306       return self._call_tf_sessionrun(
    -> 1307           options, feed_dict, fetch_list, target_list, run_metadata)
       1308 
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
       1408           self._session, options, feed_dict, fetch_list, target_list,
    -> 1409           run_metadata)
       1410     else:
    
    UnknownError: IndexError: too many indices for array
    Traceback (most recent call last):
    
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/script_ops.py", line 158, in __call__
        ret = func(*args)
    
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 410, in generator_py_func
        values = next(generator_state.get_iterator(iterator_id))
    
      File "<ipython-input-64-e6c5163f3adc>", line 26, in __next__
        features = np.array(current_obs)[:,3:26]
    
    IndexError: too many indices for array
    
    
         [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_INT32], token="pyfunc_46"](arg0)]]
         [[Node: IteratorGetNext_23 = IteratorGetNext[output_shapes=[[?,4,23], [?,4,2]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_23)]]
    
    During handling of the above exception, another exception occurred:
    
    UnknownError                              Traceback (most recent call last)
    <ipython-input-67-213eeaa1c283> in <module>()
          7         sess.run(training_init_op)
          8         for i in range(2):
    ----> 9             print(sess.run([X,y]))
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
        898     try:
        899       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 900                          run_metadata_ptr)
        901       if run_metadata:
        902         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
       1133     if final_fetches or final_targets or (handle and feed_dict_tensor):
       1134       results = self._do_run(handle, final_targets, final_fetches,
    -> 1135                              feed_dict_tensor, options, run_metadata)
       1136     else:
       1137       results = []
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
       1314     if handle is None:
       1315       return self._do_call(_run_fn, feeds, fetches, targets, options,
    -> 1316                            run_metadata)
       1317     else:
       1318       return self._do_call(_prun_fn, handle, feeds, fetches)
    
    /usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
       1333         except KeyError:
       1334           pass
    -> 1335       raise type(e)(node_def, op, message)
       1336 
       1337   def _extend_graph(self):
    
    UnknownError: IndexError: too many indices for array
    Traceback (most recent call last):
    
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/script_ops.py", line 158, in __call__
        ret = func(*args)
    
      File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 410, in generator_py_func
        values = next(generator_state.get_iterator(iterator_id))
    
      File "<ipython-input-64-e6c5163f3adc>", line 26, in __next__
        features = np.array(current_obs)[:,3:26]
    
    IndexError: too many indices for array
    
    
         [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_INT32], token="pyfunc_46"](arg0)]]
         [[Node: IteratorGetNext_23 = IteratorGetNext[output_shapes=[[?,4,23], [?,4,2]], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_23)]]
    

    I tried .repeat(2), to no avail.
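The IndexError in the traceback can be reproduced without a database. The sketch below assumes that `fetchmany` returns an empty list once the cursor has no rows left, which is how DB-API cursors behave; the 27-column sample rows and the helper name `to_features` are hypothetical. An empty list becomes a 1-D array, so the two-index slice used in `__next__` fails. This is one plausible reading of the error, not a confirmed diagnosis. It would also be consistent with `.repeat(2)` not helping, since `lambda: generator` hands back the same, already consumed object every time.

```python
import numpy as np


def to_features(rows):
    # Same slicing as in __next__: keep columns 3..25 of every row.
    return np.array(rows)[:, 3:26]


# Hypothetical chunk of 4 rows with 27 columns each.
full_chunk = [tuple(range(27)) for _ in range(4)]
print(to_features(full_chunk).shape)  # (4, 23)

# What fetchmany() hands back once the cursor is exhausted.
empty_chunk = []
try:
    to_features(empty_chunk)
    failed = False
except IndexError as exc:
    # np.array([]) is 1-D, so indexing it with two indices raises IndexError.
    failed = True
    print("IndexError:", exc)
```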

    Can anyone help me? How do I run multiple epochs when I am using a Python iterator whose data comes from a database?
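One possible direction, offered as a sketch rather than a verified fix: make the data source restartable by using a generator function that opens a fresh cursor and re-runs the query each time it is called. To keep the example self-contained, an in-memory `sqlite3` table stands in for the PostgreSQL connection; `make_connection`, `person_chunks`, and the three-column table are hypothetical.

```python
import sqlite3


def make_connection():
    """Hypothetical stand-in for psycopg2.connect(...): 3 people x 4 quarters."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE basetable (customer_id INTEGER, quarter INTEGER, label INTEGER)")
    rows = [(c, q, (c + q) % 2) for c in range(3) for q in range(1, 5)]
    conn.executemany("INSERT INTO basetable VALUES (?, ?, ?)", rows)
    return conn


def person_chunks(conn, sql, number_of_records):
    """Yield lists of `number_of_records` rows; every call starts a new cursor."""
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
        while True:
            rows = cursor.fetchmany(number_of_records)
            if len(rows) < number_of_records:
                # Exhausted (or a trailing partial chunk): stop cleanly
                # instead of yielding an empty or short array.
                return
            yield rows
    finally:
        cursor.close()


conn = make_connection()
sql = "SELECT * FROM basetable ORDER BY customer_id, quarter"

# Each epoch calls the generator function again, so the query restarts.
epochs = [list(person_chunks(conn, sql, 4)) for _ in range(2)]
print([len(e) for e in epochs])  # [3, 3]
```

With tf.data, the same idea would mean passing a callable that builds a new generator, for example `lambda: person_chunks(conn, sql, 4)` adapted to yield `(features, labels)` arrays, as the first argument of `tf.data.Dataset.from_generator`, because that callable is invoked again whenever the dataset is re-initialized or repeated.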
