!26837 tag environment bugfix

Merge pull request !26837 from chenweifeng/tag-environment-bug-fix
This commit is contained in:
i-robot 2021-11-27 07:44:27 +00:00 committed by Gitee
commit c1798df274
3 changed files with 12 additions and 13 deletions

View File

@ -29,7 +29,7 @@ const std::vector<size_t> &EnvCreateKernel::GetWorkspaceSizeList() const { retur
bool EnvCreateKernel::Init(const CNodePtr &cnode) { bool EnvCreateKernel::Init(const CNodePtr &cnode) {
const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName); const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName);
std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name); std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name);
MS_LOG(EXCEPTION) << "Create environment " << name << " failed."; MS_EXCEPTION_IF_NULL(env_);
env_->Init(cnode, nullptr); env_->Init(cnode, nullptr);
InitSizeLists(); InitSizeLists();
return true; return true;
@ -53,7 +53,7 @@ const std::vector<size_t> &EnvResetKernel::GetWorkspaceSizeList() const { return
bool EnvResetKernel::Init(const CNodePtr &cnode) { bool EnvResetKernel::Init(const CNodePtr &cnode) {
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName); handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_); env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed."; MS_EXCEPTION_IF_NULL(env_);
InitSizeLists(); InitSizeLists();
return true; return true;
} }
@ -72,7 +72,7 @@ const std::vector<size_t> &EnvStepKernel::GetWorkspaceSizeList() const { return
bool EnvStepKernel::Init(const CNodePtr &cnode) { bool EnvStepKernel::Init(const CNodePtr &cnode) {
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName); handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_); env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed."; MS_EXCEPTION_IF_NULL(env_);
InitSizeLists(); InitSizeLists();
return true; return true;
} }
@ -82,6 +82,7 @@ void EnvStepKernel::InitSizeLists() {
output_size_list_.push_back(env_->StateSizeInBytes()); output_size_list_.push_back(env_->StateSizeInBytes());
output_size_list_.push_back(env_->RewardSizeInBytes()); output_size_list_.push_back(env_->RewardSizeInBytes());
output_size_list_.push_back(env_->DoneSizeInBytes()); output_size_list_.push_back(env_->DoneSizeInBytes());
workspace_size_list_.push_back(env_->WorkspaceSizeInBytes());
} }
bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

View File

@ -34,9 +34,9 @@ constexpr auto kMapLengthAttr = "map_length";
constexpr auto kMapWidthAttr = "map_width"; constexpr auto kMapWidthAttr = "map_width";
constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty"; constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty";
constexpr auto kCatchRewardAttr = "catch_reward"; constexpr auto kCatchRewardAttr = "catch_reward";
constexpr auto kCaughtPenaltyAttr = "catched_penalty"; constexpr auto kCaughtPenaltyAttr = "caught_penalty";
constexpr auto kStepCostAttr = "step_cost"; constexpr auto kStepCostAttr = "step_cost";
constexpr auto kEnvNumAttr = "env_num"; constexpr auto kEnvNumAttr = "environment_num";
} // namespace } // namespace
TagEnvironment::~TagEnvironment() { TagEnvironment::~TagEnvironment() {
@ -47,7 +47,7 @@ TagEnvironment::~TagEnvironment() {
} }
bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) { bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) {
MS_LOG(EXCEPTION) << "The `game_setting` should not be nullprt"; MS_EXCEPTION_IF_NULL(setting_host);
setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr); setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr);
setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr); setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr);
@ -66,7 +66,7 @@ bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting
} }
bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) { bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) {
MS_LOG(EXCEPTION) << "The `state` should not be nullptr"; MS_EXCEPTION_IF_NULL(agent_state);
int total_agents_num = env_num_ * agent_num_; int total_agents_num = env_num_ * agent_num_;
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
@ -81,7 +81,6 @@ bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *
} }
bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) { bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) {
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance(); auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
allocator.FreeTensorMem(agent_setting.prey_left); allocator.FreeTensorMem(agent_setting.prey_left);
allocator.FreeTensorMem(agent_setting.time_step); allocator.FreeTensorMem(agent_setting.time_step);
@ -136,7 +135,7 @@ bool TagEnvironment::Step(const std::vector<AddressPtr> &inputs, const std::vect
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
} else { } else {
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done, StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done,
team_reward_, reinterpret_cast<cudaStream_t>(stream_ptr)); team_reward, reinterpret_cast<cudaStream_t>(stream_ptr));
} }
return true; return true;
} }
@ -184,7 +183,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
// Warmup // Warmup
StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream); StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
team_reward_, stream); team_reward, stream);
// Collect profiling info // Collect profiling info
device::gpu::CudaDeviceStream start = nullptr; device::gpu::CudaDeviceStream start = nullptr;
@ -203,7 +202,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
team_reward_, stream); team_reward, stream);
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream.");
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event.");
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event."); CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event.");

View File

@ -60,8 +60,7 @@ class TagEnvironment : public Environment {
GameSetting game_setting_host_; GameSetting game_setting_host_;
GameSetting *game_setting_device_ = nullptr; GameSetting *game_setting_device_ = nullptr;
AgentState agent_state_host_; AgentState agent_state_host_;
AgentState *agent_state_device_; AgentState *agent_state_device_ = nullptr;
float *team_reward_ = nullptr;
enum StepKernelType { kBindBlock = 0, kCrossBlock }; enum StepKernelType { kBindBlock = 0, kCrossBlock };
void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward, void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,