tag env bugfix

This commit is contained in:
wilfChen 2021-11-26 14:55:44 +08:00
parent 5048d3823b
commit ca8aba5c29
3 changed files with 12 additions and 13 deletions

View File

@ -29,7 +29,7 @@ const std::vector<size_t> &EnvCreateKernel::GetWorkspaceSizeList() const { retur
bool EnvCreateKernel::Init(const CNodePtr &cnode) {
const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName);
std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name);
MS_LOG(EXCEPTION) << "Create environment " << name << " failed.";
MS_EXCEPTION_IF_NULL(env_);
env_->Init(cnode, nullptr);
InitSizeLists();
return true;
@ -53,7 +53,7 @@ const std::vector<size_t> &EnvResetKernel::GetWorkspaceSizeList() const { return
bool EnvResetKernel::Init(const CNodePtr &cnode) {
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
MS_EXCEPTION_IF_NULL(env_);
InitSizeLists();
return true;
}
@ -72,7 +72,7 @@ const std::vector<size_t> &EnvStepKernel::GetWorkspaceSizeList() const { return
bool EnvStepKernel::Init(const CNodePtr &cnode) {
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
MS_EXCEPTION_IF_NULL(env_);
InitSizeLists();
return true;
}
@ -82,6 +82,7 @@ void EnvStepKernel::InitSizeLists() {
output_size_list_.push_back(env_->StateSizeInBytes());
output_size_list_.push_back(env_->RewardSizeInBytes());
output_size_list_.push_back(env_->DoneSizeInBytes());
workspace_size_list_.push_back(env_->WorkspaceSizeInBytes());
}
bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,

View File

@ -34,9 +34,9 @@ constexpr auto kMapLengthAttr = "map_length";
constexpr auto kMapWidthAttr = "map_width";
constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty";
constexpr auto kCatchRewardAttr = "catch_reward";
constexpr auto kCaughtPenaltyAttr = "catched_penalty";
constexpr auto kCaughtPenaltyAttr = "caught_penalty";
constexpr auto kStepCostAttr = "step_cost";
constexpr auto kEnvNumAttr = "env_num";
constexpr auto kEnvNumAttr = "environment_num";
} // namespace
TagEnvironment::~TagEnvironment() {
@ -47,7 +47,7 @@ TagEnvironment::~TagEnvironment() {
}
bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) {
MS_LOG(EXCEPTION) << "The `game_setting` should not be nullprt";
MS_EXCEPTION_IF_NULL(setting_host);
setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr);
setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr);
@ -66,7 +66,7 @@ bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting
}
bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) {
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
MS_EXCEPTION_IF_NULL(agent_state);
int total_agents_num = env_num_ * agent_num_;
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
@ -81,7 +81,6 @@ bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *
}
bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) {
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
allocator.FreeTensorMem(agent_setting.prey_left);
allocator.FreeTensorMem(agent_setting.time_step);
@ -136,7 +135,7 @@ bool TagEnvironment::Step(const std::vector<AddressPtr> &inputs, const std::vect
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done,
team_reward_, reinterpret_cast<cudaStream_t>(stream_ptr));
team_reward, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
@ -184,7 +183,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
// Warmup
StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
team_reward_, stream);
team_reward, stream);
// Collect profiling info
device::gpu::CudaDeviceStream start = nullptr;
@ -203,7 +202,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
team_reward_, stream);
team_reward, stream);
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream.");
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event.");
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event.");

View File

@ -60,8 +60,7 @@ class TagEnvironment : public Environment {
GameSetting game_setting_host_;
GameSetting *game_setting_device_ = nullptr;
AgentState agent_state_host_;
AgentState *agent_state_device_;
float *team_reward_ = nullptr;
AgentState *agent_state_device_ = nullptr;
enum StepKernelType { kBindBlock = 0, kCrossBlock };
void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,