代码之家  ›  专栏  ›  技术社区  ›  Irene Ferfoglia

ImportError:无法从“ray.air”导入名称“Checkpoint”

  •  0
  • Irene Ferfoglia  · 技术社区  · 2 年前

    我正试着遵循这个 tutorial 要使用Ray调整PyTorch中的超参数,请复制粘贴所有内容,但我收到以下错误:

    ImportError: cannot import name 'Checkpoint' from 'ray.air'
    

    从这个进口行:

    from ray.air import Checkpoint
    

    我使用安装了ray pip install -U "ray[tune]" 正如官方网站上所建议的那样。在得到错误后,可以肯定的是,我也尝试了一个更通用的 pip install ray ,这并没有解决任何问题。
    我有版本 ray==2.9.0 安装。

    有什么帮助吗?

    1 回复  |  直到 2 年前
        1
  •  0
  •   حمزة نبيل    2 年前

    尝试安装旧版本 2.7.0 :

    pip install ray[tune]==2.7.0
    

    更新:

    对于最新版本,Ray AIR会话将替换为Ray Train上下文对象。 您可以导入 Checkpoint 使用:

    from ray.train import Checkpoint
    

    您需要调整您的代码,如下所示:

    from ray import air, train
    
    # Ray Train methods and classes:
    air.session.report               -> train.report
    air.session.get_dataset_shard    -> train.get_dataset_shard
    air.session.get_checkpoint       -> train.get_checkpoint
    air.Checkpoint                   -> train.Checkpoint
    air.Result                       -> train.Result
    
    # Ray Train configurations:
    air.config.CheckpointConfig      -> train.CheckpointConfig
    air.config.FailureConfig         -> train.FailureConfig
    air.config.RunConfig             -> train.RunConfig
    air.config.ScalingConfig         -> train.ScalingConfig
    
    # Ray TrainContext methods:
    air.session.get_experiment_name  -> train.get_context().get_experiment_name
    air.session.get_trial_name       -> train.get_context().get_trial_name
    air.session.get_trial_id         -> train.get_context().get_trial_id
    air.session.get_trial_resources  -> train.get_context().get_trial_resources
    air.session.get_trial_dir        -> train.get_context().get_trial_dir
    air.session.get_world_size       -> train.get_context().get_world_size
    air.session.get_world_rank       -> train.get_context().get_world_rank
    air.session.get_local_rank       -> train.get_context().get_local_rank
    air.session.get_local_world_size -> train.get_context().get_local_world_size
    air.session.get_node_rank        -> train.get_context().get_node_rank
    

    有关更多信息,请参阅: