forked from mindspore-Ecosystem/mindspore
!26837 tag environment bugfix
Merge pull request !26837 from chenweifeng/tag-environment-bug-fix
This commit is contained in:
commit
c1798df274
|
@ -29,7 +29,7 @@ const std::vector<size_t> &EnvCreateKernel::GetWorkspaceSizeList() const { retur
|
||||||
bool EnvCreateKernel::Init(const CNodePtr &cnode) {
|
bool EnvCreateKernel::Init(const CNodePtr &cnode) {
|
||||||
const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName);
|
const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName);
|
||||||
std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name);
|
std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name);
|
||||||
MS_LOG(EXCEPTION) << "Create environment " << name << " failed.";
|
MS_EXCEPTION_IF_NULL(env_);
|
||||||
env_->Init(cnode, nullptr);
|
env_->Init(cnode, nullptr);
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
return true;
|
return true;
|
||||||
|
@ -53,7 +53,7 @@ const std::vector<size_t> &EnvResetKernel::GetWorkspaceSizeList() const { return
|
||||||
bool EnvResetKernel::Init(const CNodePtr &cnode) {
|
bool EnvResetKernel::Init(const CNodePtr &cnode) {
|
||||||
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
||||||
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
||||||
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
|
MS_EXCEPTION_IF_NULL(env_);
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ const std::vector<size_t> &EnvStepKernel::GetWorkspaceSizeList() const { return
|
||||||
bool EnvStepKernel::Init(const CNodePtr &cnode) {
|
bool EnvStepKernel::Init(const CNodePtr &cnode) {
|
||||||
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
||||||
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
||||||
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
|
MS_EXCEPTION_IF_NULL(env_);
|
||||||
InitSizeLists();
|
InitSizeLists();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -82,6 +82,7 @@ void EnvStepKernel::InitSizeLists() {
|
||||||
output_size_list_.push_back(env_->StateSizeInBytes());
|
output_size_list_.push_back(env_->StateSizeInBytes());
|
||||||
output_size_list_.push_back(env_->RewardSizeInBytes());
|
output_size_list_.push_back(env_->RewardSizeInBytes());
|
||||||
output_size_list_.push_back(env_->DoneSizeInBytes());
|
output_size_list_.push_back(env_->DoneSizeInBytes());
|
||||||
|
workspace_size_list_.push_back(env_->WorkspaceSizeInBytes());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
|
|
@ -34,9 +34,9 @@ constexpr auto kMapLengthAttr = "map_length";
|
||||||
constexpr auto kMapWidthAttr = "map_width";
|
constexpr auto kMapWidthAttr = "map_width";
|
||||||
constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty";
|
constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty";
|
||||||
constexpr auto kCatchRewardAttr = "catch_reward";
|
constexpr auto kCatchRewardAttr = "catch_reward";
|
||||||
constexpr auto kCaughtPenaltyAttr = "catched_penalty";
|
constexpr auto kCaughtPenaltyAttr = "caught_penalty";
|
||||||
constexpr auto kStepCostAttr = "step_cost";
|
constexpr auto kStepCostAttr = "step_cost";
|
||||||
constexpr auto kEnvNumAttr = "env_num";
|
constexpr auto kEnvNumAttr = "environment_num";
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
TagEnvironment::~TagEnvironment() {
|
TagEnvironment::~TagEnvironment() {
|
||||||
|
@ -47,7 +47,7 @@ TagEnvironment::~TagEnvironment() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) {
|
bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) {
|
||||||
MS_LOG(EXCEPTION) << "The `game_setting` should not be nullprt";
|
MS_EXCEPTION_IF_NULL(setting_host);
|
||||||
|
|
||||||
setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr);
|
setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr);
|
||||||
setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr);
|
setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr);
|
||||||
|
@ -66,7 +66,7 @@ bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) {
|
bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) {
|
||||||
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
|
MS_EXCEPTION_IF_NULL(agent_state);
|
||||||
|
|
||||||
int total_agents_num = env_num_ * agent_num_;
|
int total_agents_num = env_num_ * agent_num_;
|
||||||
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
||||||
|
@ -81,7 +81,6 @@ bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) {
|
bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) {
|
||||||
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
|
|
||||||
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
||||||
allocator.FreeTensorMem(agent_setting.prey_left);
|
allocator.FreeTensorMem(agent_setting.prey_left);
|
||||||
allocator.FreeTensorMem(agent_setting.time_step);
|
allocator.FreeTensorMem(agent_setting.time_step);
|
||||||
|
@ -136,7 +135,7 @@ bool TagEnvironment::Step(const std::vector<AddressPtr> &inputs, const std::vect
|
||||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
} else {
|
} else {
|
||||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done,
|
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done,
|
||||||
team_reward_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
team_reward, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -184,7 +183,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
|
||||||
// Warmup
|
// Warmup
|
||||||
StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);
|
StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);
|
||||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
||||||
team_reward_, stream);
|
team_reward, stream);
|
||||||
|
|
||||||
// Collect profiling info
|
// Collect profiling info
|
||||||
device::gpu::CudaDeviceStream start = nullptr;
|
device::gpu::CudaDeviceStream start = nullptr;
|
||||||
|
@ -203,7 +202,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
|
||||||
|
|
||||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
|
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
|
||||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
||||||
team_reward_, stream);
|
team_reward, stream);
|
||||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream.");
|
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream.");
|
||||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event.");
|
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event.");
|
||||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event.");
|
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event.");
|
||||||
|
|
|
@ -60,8 +60,7 @@ class TagEnvironment : public Environment {
|
||||||
GameSetting game_setting_host_;
|
GameSetting game_setting_host_;
|
||||||
GameSetting *game_setting_device_ = nullptr;
|
GameSetting *game_setting_device_ = nullptr;
|
||||||
AgentState agent_state_host_;
|
AgentState agent_state_host_;
|
||||||
AgentState *agent_state_device_;
|
AgentState *agent_state_device_ = nullptr;
|
||||||
float *team_reward_ = nullptr;
|
|
||||||
|
|
||||||
enum StepKernelType { kBindBlock = 0, kCrossBlock };
|
enum StepKernelType { kBindBlock = 0, kCrossBlock };
|
||||||
void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,
|
void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,
|
||||||
|
|
Loading…
Reference in New Issue