forked from mindspore-Ecosystem/mindspore
tag env bugfix
This commit is contained in:
parent
5048d3823b
commit
ca8aba5c29
|
@ -29,7 +29,7 @@ const std::vector<size_t> &EnvCreateKernel::GetWorkspaceSizeList() const { retur
|
|||
bool EnvCreateKernel::Init(const CNodePtr &cnode) {
|
||||
const auto &name = AnfAlgo::GetNodeAttr<std::string>(cnode, kEnvTypeName);
|
||||
std::tie(handle_, env_) = EnvironmentFactory::GetInstance().Create(name);
|
||||
MS_LOG(EXCEPTION) << "Create environment " << name << " failed.";
|
||||
MS_EXCEPTION_IF_NULL(env_);
|
||||
env_->Init(cnode, nullptr);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
|
@ -53,7 +53,7 @@ const std::vector<size_t> &EnvResetKernel::GetWorkspaceSizeList() const { return
|
|||
bool EnvResetKernel::Init(const CNodePtr &cnode) {
|
||||
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
||||
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
||||
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
|
||||
MS_EXCEPTION_IF_NULL(env_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ const std::vector<size_t> &EnvStepKernel::GetWorkspaceSizeList() const { return
|
|||
bool EnvStepKernel::Init(const CNodePtr &cnode) {
|
||||
handle_ = AnfAlgo::GetNodeAttr<int64_t>(cnode, kHandleAttrName);
|
||||
env_ = EnvironmentFactory::GetInstance().GetByHandle(handle_);
|
||||
MS_LOG(EXCEPTION) << "Get environment handle " << handle_ << " failed.";
|
||||
MS_EXCEPTION_IF_NULL(env_);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -82,6 +82,7 @@ void EnvStepKernel::InitSizeLists() {
|
|||
output_size_list_.push_back(env_->StateSizeInBytes());
|
||||
output_size_list_.push_back(env_->RewardSizeInBytes());
|
||||
output_size_list_.push_back(env_->DoneSizeInBytes());
|
||||
workspace_size_list_.push_back(env_->WorkspaceSizeInBytes());
|
||||
}
|
||||
|
||||
bool EnvStepKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
|
|
|
@ -34,9 +34,9 @@ constexpr auto kMapLengthAttr = "map_length";
|
|||
constexpr auto kMapWidthAttr = "map_width";
|
||||
constexpr auto kWallHitPenaltyAttr = "wall_hit_penalty";
|
||||
constexpr auto kCatchRewardAttr = "catch_reward";
|
||||
constexpr auto kCaughtPenaltyAttr = "catched_penalty";
|
||||
constexpr auto kCaughtPenaltyAttr = "caught_penalty";
|
||||
constexpr auto kStepCostAttr = "step_cost";
|
||||
constexpr auto kEnvNumAttr = "env_num";
|
||||
constexpr auto kEnvNumAttr = "environment_num";
|
||||
} // namespace
|
||||
|
||||
TagEnvironment::~TagEnvironment() {
|
||||
|
@ -47,7 +47,7 @@ TagEnvironment::~TagEnvironment() {
|
|||
}
|
||||
|
||||
bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting_host) {
|
||||
MS_LOG(EXCEPTION) << "The `game_setting` should not be nullprt";
|
||||
MS_EXCEPTION_IF_NULL(setting_host);
|
||||
|
||||
setting_host->seed = AnfAlgo::GetNodeAttr<int64_t>(cnode, kSeedAttr);
|
||||
setting_host->predator_num = AnfAlgo::GetNodeAttr<int64_t>(cnode, kPredatorNumAttr);
|
||||
|
@ -66,7 +66,7 @@ bool TagEnvironment::InitGameSetting(const CNodePtr &cnode, GameSetting *setting
|
|||
}
|
||||
|
||||
bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *agent_state) {
|
||||
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
|
||||
MS_EXCEPTION_IF_NULL(agent_state);
|
||||
|
||||
int total_agents_num = env_num_ * agent_num_;
|
||||
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
||||
|
@ -81,7 +81,6 @@ bool TagEnvironment::InitAgentState(int predator_num, int prey_num, AgentState *
|
|||
}
|
||||
|
||||
bool TagEnvironment::FinalizeAgentState(const AgentState &agent_setting) {
|
||||
MS_LOG(EXCEPTION) << "The `state` should not be nullptr";
|
||||
auto &allocator = device::gpu::GPUMemoryAllocator::GetInstance();
|
||||
allocator.FreeTensorMem(agent_setting.prey_left);
|
||||
allocator.FreeTensorMem(agent_setting.time_step);
|
||||
|
@ -136,7 +135,7 @@ bool TagEnvironment::Step(const std::vector<AddressPtr> &inputs, const std::vect
|
|||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device_, action, state, reward, done,
|
||||
team_reward_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
team_reward, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -184,7 +183,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
|
|||
// Warmup
|
||||
StepBindBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done, stream);
|
||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
||||
team_reward_, stream);
|
||||
team_reward, stream);
|
||||
|
||||
// Collect profiling info
|
||||
device::gpu::CudaDeviceStream start = nullptr;
|
||||
|
@ -203,7 +202,7 @@ void TagEnvironment::StepKernelProfiling(const int *action, float *state, float
|
|||
|
||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(start, stream), "Failed to record event to stream.");
|
||||
StepCrossBlock(env_num_, agent_num_, game_setting_device_, agent_state_device, action, state, reward, done,
|
||||
team_reward_, stream);
|
||||
team_reward, stream);
|
||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::RecordEvent(end, stream), "Failed to record event to stream.");
|
||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(start), "Failed to sync event.");
|
||||
CHECK_OP_RET_WITH_EXCEPT(device::gpu::CudaDriver::SyncEvent(end), "Failed to sync event.");
|
||||
|
|
|
@ -60,8 +60,7 @@ class TagEnvironment : public Environment {
|
|||
GameSetting game_setting_host_;
|
||||
GameSetting *game_setting_device_ = nullptr;
|
||||
AgentState agent_state_host_;
|
||||
AgentState *agent_state_device_;
|
||||
float *team_reward_ = nullptr;
|
||||
AgentState *agent_state_device_ = nullptr;
|
||||
|
||||
enum StepKernelType { kBindBlock = 0, kCrossBlock };
|
||||
void StepKernelProfiling(const int *action, float *state, float *reward, bool *done, float *team_reward,
|
||||
|
|
Loading…
Reference in New Issue