!31529 priority replay buffer aicpu library select

Merge pull request !31529 from chenweifeng/aicpu-priority-replay-buffer
This commit is contained in:
i-robot 2022-03-21 06:32:13 +00:00 committed by Gitee
commit 1dba54470c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
5 changed files with 34 additions and 17 deletions

View File

@ -350,6 +350,10 @@ constexpr auto kEnvironSetOpName = "EnvironSet";
constexpr auto kEnvironGetOpName = "EnvironGet";
constexpr auto kEnvironDestroyAllOpName = "EnvironDestroyAll";
constexpr auto kUpdateStateOpName = "UpdateState";
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
constexpr auto kPriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
// Communication world group
constexpr auto kNcclWorldGroup = "nccl_world_group";

View File

@ -75,14 +75,26 @@ constexpr auto kEnvironCreate = "EnvironCreate";
constexpr auto kEnvironSet = "EnvironSet";
constexpr auto kEnvironGet = "EnvironGet";
constexpr auto kEnvironDestroyAll = "EnvironDestroyAll";
constexpr auto kPriorityReplayBufferCreate = "PriorityReplayBufferCreate";
constexpr auto kPriorityReplayBufferPush = "PriorityReplayBufferPush";
constexpr auto kPriorityReplayBufferSample = "PriorityReplayBufferSample";
constexpr auto kPriorityReplayBufferUpdate = "PriorityReplayBufferUpdate";
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
kDropout2D, kNonMaxSuppressionV3, kGetNext, kInitData, kPrint};
const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask, kEnvironCreate, kEnvironSet, kEnvironGet,
kEnvironDestroyAll};
const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
kEnvironCreate,
kEnvironSet,
kEnvironGet,
kEnvironDestroyAll,
kPriorityReplayBufferCreate,
kPriorityReplayBufferPush,
kPriorityReplayBufferSample,
kPriorityReplayBufferUpdate};
const std::set<std::string> kDynamicInputOps{
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName, kStackPushOpName, kStackPopOpName, kDynamicStitch};
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName,
kStackPushOpName, kStackPopOpName, kDynamicStitch, kPriorityReplayBufferPush, kPriorityReplayBufferSample};
struct AicpuParamHead {
uint32_t length; // Total length: include cunstom message
uint32_t ioAddrNum; // Input and output address number

View File

@ -139,7 +139,7 @@
#include "plugin/device/ascend/optimizer/enhancer/add_attr_for_3d_graph.h"
#include "plugin/device/ascend/optimizer/enhancer/split_n_optimizer.h"
#include "plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.h"
#include "plugin/device/ascend/optimizer/mindir/env_op_attr_update.h"
#include "plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h"
#include "plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.h"
#include "plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.h"
@ -575,7 +575,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &graph) {
auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
unify_mindir_pm->AddPass(std::make_shared<opt::SpaceToBatchNDAttrUpdate>());
unify_mindir_pm->AddPass(std::make_shared<opt::BatchToSpaceNDAttrUpdate>());
unify_mindir_pm->AddPass(std::make_shared<opt::EnvOpAttrUpdate>());
unify_mindir_pm->AddPass(std::make_shared<opt::AICpuLibSelectPass>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPool2MaxPoolWithArgmax>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolWithArgmaxUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::MaxPoolGradWithArgmaxUnifyMindIR>());

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/mindir/env_op_attr_update.h"
#include "plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h"
#include <set>
#include <string>
#include "include/common/utils/utils.h"
@ -23,21 +23,22 @@
namespace mindspore {
namespace opt {
const AnfNodePtr EnvOpAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
static const std::set<std::string> kEnvOpNames = {kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName,
kEnvironDestroyAllOpName};
static const std::set<std::string> kAICpuOpNames = {
kEnvironCreateOpName, kEnvironSetOpName, kEnvironGetOpName, kEnvironDestroyAllOpName,
kPriorityReplayBufferCreate, kPriorityReplayBufferPush, kPriorityReplayBufferSample, kPriorityReplayBufferUpdate};
static const std::string kEnvOpSoNames = "mindspore_aicpu_kernels";
if (!node->isa<CNode>()) {
return node;
}
auto kernel_name = common::AnfAlgo::GetCNodeName(node);
if (kEnvOpNames.find(kernel_name) != kEnvOpNames.end()) {
if (kAICpuOpNames.find(kernel_name) != kAICpuOpNames.end()) {
common::AnfAlgo::SetNodeAttr(kAttrCustAicpu, MakeValue(kEnvOpSoNames), node);
}

View File

@ -13,19 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ENV_OP_ATTR_UPDATE_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ENV_OP_ATTR_UPDATE_H_
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_
#include "backend/common/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
class EnvOpAttrUpdate : public PatternProcessPass {
class AICpuLibSelectPass : public PatternProcessPass {
public:
explicit EnvOpAttrUpdate(bool multigraph = true) : PatternProcessPass("env_op_attr_update", multigraph) {}
~EnvOpAttrUpdate() override = default;
explicit AICpuLibSelectPass(bool multigraph = true) : PatternProcessPass("env_op_attr_update", multigraph) {}
~AICpuLibSelectPass() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ENV_OP_ATTR_UPDATE_H_
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_