
tf.data.Dataset.from_generator taking a very long time to initialize

  • lr100 · asked 2 years ago

    I have a generator that I am trying to feed into a tf.data.Dataset.

    def static_syn_batch_generator(
            total_size: int, batch_size: int, start_random_seed:int=0, 
            fg_seeds_ss:SampleSet=None, bg_seeds_ss:SampleSet=None, target_level:str="Isotope"):
        
        static_syn = StaticSynthesizer(
            samples_per_seed = 10, # will be updated in generator
            snr_function ="log10",
            random_state = 0 # will be updated in generator
        )
        static_syn.random_state = start_random_seed
        static_syn.samples_per_seed = math.ceil(batch_size / (len(fg_seeds_ss) * len(bg_seeds_ss)))
        # print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
        # print(f"static_syn.random_state={static_syn.random_state}")
    
        counter = 0
        for i in range(total_size):
            # Regenerate for each batch
            if counter%batch_size == 0: # Regen data for every batch
                fg, bg, gross = static_syn.generate(fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
                fg_sources_cont_df = fg.sources.groupby(axis=1, level=target_level).sum()
                bg_sources_cont_df = bg.sources.groupby(axis=1, level=target_level).sum()
                gross_sources_cont_df = gross.sources.groupby(axis=1, level=target_level).sum()
                static_syn.random_state += 1
                print(static_syn.random_state)
                # print(f"static_syn.samples_per_seed={static_syn.samples_per_seed}")
                # print(f"static_syn.random_state={static_syn.random_state}")
    
            fg_X = fg.spectra.values[i%batch_size]
            fg_y = fg_sources_cont_df.values[i%batch_size].astype(float)
            bg_X = bg.spectra.values[i%batch_size]
            bg_y = bg_sources_cont_df.values[i%batch_size].astype(float)
            gross_X = gross.spectra.values[i%batch_size]
            gross_y = gross_sources_cont_df.values[i%batch_size].astype(float)
    
            
            yield (fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)
            
    
            counter += 1
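
    For context, each yield above produces a nested ((X, y), (X, y), (X, y)) structure for the foreground, background, and gross spectra. Expressed as TensorSpecs it would look roughly like the sketch below; the dimensions are left as None here because the real channel count and label width depend on the seed SampleSets and target_level:

    import tensorflow as tf

    # Sketch only: unknown dimensions are left as None; concrete sizes could be
    # filled in if the spectrum length and number of target_level columns are known.
    pair_spec = (
        tf.TensorSpec(shape=(None,), dtype=tf.float32),  # X: one spectrum
        tf.TensorSpec(shape=(None,), dtype=tf.float32),  # y: source contributions
    )
    output_signature = (pair_spec, pair_spec, pair_spec)  # (fg, bg, gross)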
    

    When run manually it works, taking about 6 seconds to produce the output and compare two instances of the generator (to make sure the random seeding works):

    total_size = 10
    batch_size = 2
    
    batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
    fg0 = []
    bg0 =[]
    gross0 = []
    for i, ((fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)) in enumerate(batch_gen):
      fg0.append(fg_X)
      bg0.append(bg_X)
      gross0.append(gross_X)  
    
    print(f"len of fg0: {len(fg0)}")
    print(f"len of bg0: {len(bg0)}")
    print(f"len of gross0: {len(gross0)}")
    
    batch_gen = static_syn_batch_generator(total_size, batch_size, start_random_seed=0, fg_seeds_ss=fg_seeds_ss, bg_seeds_ss=bg_seeds_ss)
    fg1 = []
    bg1 =[]
    gross1 = []
    for i, ((fg_X, fg_y), (bg_X, bg_y), (gross_X, gross_y)) in enumerate(batch_gen):
      fg1.append(fg_X)
      bg1.append(bg_X)
      gross1.append(gross_X)  
    
    print(f"len of fg1: {len(fg1)}")
    print(f"len of bg1: {len(bg1)}")
    print(f"len of gross1: {len(gross1)}")
    
    
    assert np.array_equal(fg0, fg1)
    assert np.array_equal(bg0, bg1)
    assert np.array_equal(gross0, gross1)
    

    However, when I try to instantiate a tf.data.Dataset.from_generator, it takes an extremely long time to initialize (I actually don't know whether it ever finished; it was still running at the 15-minute mark).

    fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
    
    total_samples = 10
    batch_size = 2
    start_random_seed = 0
    
    #TODO: TAKES FOREVER
    dataset = tf.data.Dataset.from_generator(
        generator=static_syn_batch_generator,
        args=(total_samples, batch_size, start_random_seed, fg_seeds_ss, bg_seeds_ss, "Isotope"),
        output_types=((tf.float32, tf.float32),(tf.float32, tf.float32),(tf.float32, tf.float32))
    )
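
    One thing I have been wondering about is the args= tuple: from_generator evaluates args as tensors before handing them to the generator, and fg_seeds_ss / bg_seeds_ss are SampleSet objects rather than arrays, so that conversion might be where it gets stuck. A sketch of the alternative I was considering, capturing the seeds in a lambda and describing the elements with output_signature (same placeholder TensorSpecs as above), is:

    import tensorflow as tf

    # Sketch only: the seed SampleSets are captured by the closure instead of being
    # passed through args=, and the element structure uses placeholder TensorSpecs.
    pair_spec = (
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.float32),
    )
    dataset = tf.data.Dataset.from_generator(
        lambda: static_syn_batch_generator(
            total_samples, batch_size,
            start_random_seed=start_random_seed,
            fg_seeds_ss=fg_seeds_ss,
            bg_seeds_ss=bg_seeds_ss,
            target_level="Isotope",
        ),
        output_signature=(pair_spec, pair_spec, pair_spec),
    )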
    

    Does anyone have any suggestions, or see what I am doing wrong?

    0 replies  |  as of 2 years ago