diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/env_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/env_gpu_kernel.cc index ddad86a4771..872616180bb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/env_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/env_gpu_kernel.cc @@ -29,7 +29,7 @@ const std::vector &EnvCreateKernel::GetWorkspaceSizeList() const { retur bool EnvCreateKernel::Init(const CNodePtr &cnode) { const auto &name = AnfAlgo::GetNodeAttr(cnode, kEnvTypeName); std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name); - MS_LOG(EXCEPTION) << "Create environment " << name << " failed."; + MS_EXCEPTION_IF_NULL(env_); env_->Init(cnode, nullptr); InitSizeLists(); return true; @@ -53,7 +53,7 @@ const std::vector &EnvResetKernel::GetWorkspaceSizeList() const { return bool EnvResetKernel::Init(const CNodePtr &cnode) { handle_ = AnfAlgo::GetNodeAttr(cnode, kHandleAttrName); env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_); - MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed."; + MS_EXCEPTION_IF_NULL(env_); InitSizeLists(); return true; } @@ -72,7 +72,7 @@ const std::vector &EnvStepKernel::GetWorkspaceSizeList() const { return bool EnvStepKernel::Init(const CNodePtr &cnode) { handle_ = AnfAlgo::GetNodeAttr(cnode, kHandleAttrName); env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_); - MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed."; + MS_EXCEPTION_IF_NULL(env_); InitSizeLists(); return true; } @@ -82,6 +82,7 @@ void EnvStepKernel::InitSizeLists() { output_size_list_.push_back(env_->StateSizeInBytes()); output_size_list_.push_back(env_->RewardSizeInBytes()); output_size_list_.push_back(env_->DoneSizeInBytes()); + workspace_size_list_.push_back(env_->WorkspaceSizeInBytes()); } bool EnvStepKernel::Launch(const std::vector &inputs, const std::vector &workspace, diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.cc index 952c2c78d27..2c4189c48a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.cc @@ -34,9 +34,9 @@ constexpr auto kMapLengthAttr = "map_length"; constexpr auto kMapWidthAttr = "map_width"; constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty"; constexpr auto kCatchRewardAttr = "catch_reward"; -constexpr auto kCaughtPenaltyAttr = "catched_penalty"; +constexpr auto kCaughtPenaltyAttr = "caught_penalty"; constexpr auto kStepCostAttr = "step_cost"; -constexpr auto kEnvNumAttr = "env_num"; +constexpr auto kEnvNumAttr = "environment_num"; } // namespace TagEnvironment::~TagEnvironment() { @@ -47,7 +47,7 @@ TagEnvironment::~TagEnvironment() { } bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) { - MS_LOG(EXCEPTION) << "The `game_setting` should not be nullprt"; + MS_EXCEPTION_IF_NULL(setting_host); setting_host->seed = AnfAlgo::GetNodeAttr(cnode, kSeedAttr); setting_host->predator_num = AnfAlgo::GetNodeAttr(cnode, kPredatorNumAttr); @@ -66,7 +66,7 @@ bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting } bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) { - MS_LOG(EXCEPTION) << "The `state` should not be nullptr"; + MS_EXCEPTION_IF_NULL(agent_state); int total_agents_num = env_num_ * agent_num_; auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); @@ -81,7 +81,6 @@ bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState * } bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) { - MS_LOG(EXCEPTION) << "The `state` should not be nullptr"; auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); allocator.FreeTensorMem(agent_setting.prey_left); allocator.FreeTensorMem(agent_setting.time_step); @@ -136,7 +135,7 @@ bool TagEnvironment::Step(const std::vector &inputs, const std::vect reinterpret_cast(stream_ptr)); } else { StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done, - team_reward_, reinterpret_cast(stream_ptr)); + team_reward, reinterpret_cast(stream_ptr)); } return true; } @@ -184,7 +183,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float // Warmup StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream); StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, - team_reward_, stream); + team_reward, stream); // Collect profiling info device::gpu::CudaDeviceStream start = nullptr; @@ -203,7 +202,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream."); StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, - team_reward_, stream); + team_reward, stream); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event."); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.h index a32905b12ac..fbda5864e7a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/rl/tag_environment.h @@ -60,8 +60,7 @@ class TagEnvironment : public Environment { GameSetting game_setting_host_; GameSetting *game_setting_device_ = nullptr; AgentState agent_state_host_; - AgentState *agent_state_device_; - float *team_reward_ = nullptr; + AgentState *agent_state_device_ = nullptr; enum StepKernelType { kBindBlock = 0, kCrossBlock }; void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,