!25355 environment factory

Merge pull request !25355 from chenweifeng/environment-register
This commit is contained in:
i-robot 2021-10-29 01:16:29 +00:00 committed by Gitee
commit bf54552228
2 changed files with 154 additions and 0 deletions

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/rl/environment_factory.h"
#include <utility>
namespace mindspore {
namespace kernel {
EnvironmentFactory &EnvironmentFactory::GetInstance() {
static EnvironmentFactory instance;
return instance;
}
std::tuple<int, std::shared_ptr<Environment>> EnvironmentFactory::Create(const std::string &name) {
auto env_iter = map_env_name_to_creators_.find(name);
if (env_iter == map_env_name_to_creators_.end()) {
std::ostringstream oss;
oss << "Environment " << name << " not exist.\n";
oss << "Environment registered list: [";
for (auto iter = map_env_name_to_creators_.begin(); iter != map_env_name_to_creators_.end(); iter++) {
oss << iter->first << " ";
}
oss << "]";
MS_LOG(EXCEPTION) << oss.str();
}
auto env = std::shared_ptr<Environment>(env_iter->second());
map_env_handle_to_instances_.insert(std::make_pair(++handle_, env));
return std::make_tuple(handle_, env);
}
void EnvironmentFactory::Delete(int64_t handle) { map_env_handle_to_instances_.erase(handle); }
std::shared_ptr<Environment> EnvironmentFactory::GetByHandle(int64_t handle) {
auto iter = map_env_handle_to_instances_.find(handle);
if (iter == map_env_handle_to_instances_.end()) {
MS_LOG(EXCEPTION) << "Environment " << handle << " not exist.";
}
return iter->second;
}
void EnvironmentFactory::Register(const std::string &name, EnvCreator &&creator) {
map_env_name_to_creators_.insert(std::make_pair(name, creator));
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,95 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ENV_FACTORY_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ENV_FACTORY_H_
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <map>
#include <utility>
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace kernel {
// Interface of MindSpore built-in environment.
// The derive class should override all virtual interface.
class Environment {
public:
// Initialize Environment.
virtual bool Init(const CNodePtr &cnode, void *stream_ptr) = 0;
// Reset Environment.
virtual bool Reset(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
// Run one timestep.
virtual bool Step(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
// Return Environment specification and the framework will malloc device memory for it.
virtual size_t ActionSizeInBytes() = 0;
virtual size_t StateSizeInBytes() = 0;
virtual size_t RewardSizeInBytes() = 0;
virtual size_t DoneSizeInBytes() = 0;
};
constexpr int64_t kInvalidHandle = -1;
using EnvCreator = std::function<Environment *()>;
// Class for environment registration, environment instances managerment with factory design pattern.
class EnvironmentFactory {
public:
// Create a factory instance with lazy mode.
static EnvironmentFactory &GetInstance();
// Create an environment instance with unique handle and instance returned.
std::tuple<int, std::shared_ptr<Environment>> Create(const std::string &name);
// Delete the environment instance.
void Delete(int64_t handle);
// Get environment instance by handle.
std::shared_ptr<Environment> GetByHandle(int64_t handle);
// Register environment creator.
void Register(const std::string &name, EnvCreator &&creator);
private:
EnvironmentFactory() = default;
~EnvironmentFactory() = default;
DISABLE_COPY_AND_ASSIGN(EnvironmentFactory)
int64_t handle_ = kInvalidHandle;
std::map<int64_t, std::shared_ptr<Environment>> map_env_handle_to_instances_;
std::map<std::string, EnvCreator> map_env_name_to_creators_;
};
// Class for environment registration.
// The registration depend on global variable initialization(constructor) which means that
// registration executed is before than `main()` and execution order is not guaranteed to be same.
class EnvironmentRegister {
public:
EnvironmentRegister(const std::string &name, EnvCreator &&creator) {
EnvironmentFactory::GetInstance().Register(name, std::move(creator));
}
};
// Helper macro for environment registration.
#define MS_REG_GPU_ENV(NAME, ENVCLASS) \
static_assert(std::is_base_of<Environment, ENVCLASS>::value, " must be base of Environment"); \
static const EnvironmentRegister g_##NAME##_gpu_env_reg(#NAME, []() { return new ENVCLASS(); });
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_ENV_FACTORY_H_