!40668 move ascend const input to attr pass to unify mindir

Merge pull request !40668 from laiyongqiang/new_pass
i-robot 2022-08-29 01:25:58 +00:00 committed by Gitee
commit ccf3db786c
12 changed files with 443 additions and 19 deletions
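In effect, this change moves the Ascend-specific conversion of constant node inputs into node attributes out of the common backend pass (and out of the PyNative path) and into a dedicated AscendConvertConstInputToAttr pass that runs during the unify-mindir stage. As background, a minimal standalone C++ sketch of what such a conversion does is shown below; the types and names are invented for illustration and are not MindSpore's IR:

// Minimal standalone sketch (toy types, invented names; not MindSpore's IR):
// a node input that holds a compile-time constant, such as Transpose's permutation,
// is removed from the input list and stored as an attribute on the node instead.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyNode {
  std::string op_name;
  std::vector<std::string> inputs;                    // symbolic input names
  std::map<std::string, std::vector<int64_t>> attrs;  // attr name -> constant value
};

// Move the constant input at `index` into an attribute named `attr_name`.
// `const_value` stands in for the value that a ValueNode would hold in the real IR.
ToyNode ConvertConstInputToAttr(ToyNode node, size_t index, const std::string &attr_name,
                                std::vector<int64_t> const_value) {
  node.attrs[attr_name] = std::move(const_value);
  node.inputs.erase(node.inputs.begin() + static_cast<std::ptrdiff_t>(index));
  return node;
}

int main() {
  ToyNode transpose{"Transpose", {"x", "perm_const"}, {}};
  ToyNode converted = ConvertConstInputToAttr(transpose, 1, "perm", {1, 0});
  std::cout << converted.op_name << ": " << converted.inputs.size() << " remaining input(s), attr 'perm' has "
            << converted.attrs["perm"].size() << " element(s)" << std::endl;
  return 0;
}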

View File

@@ -20,6 +20,7 @@
"mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
"mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
"mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
"mindspore/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr_factory.h" "runtime/explicit"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"

View File

@@ -129,7 +129,7 @@ class BACKEND_EXPORT ConvertOpInfoRegister {
class RegisterHelper {
public:
RegisterHelper(const std::string &name, const std::string &device_name, bool is_dynamic_shape, int len, ...);
explicit RegisterHelper(const ConvertOpInfo &convert_op_info);
RegisterHelper(const ConvertOpInfo &convert_op_info);
~RegisterHelper() = default;
private:

View File

@@ -42,6 +42,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
<< ", primitive_target: " << primitive_target;
backend = primitive_target;
}
// The Ascend const-input-to-attr conversion has been moved to AscendConvertConstInputToAttr
if (backend == kAscendDevice) {
return nullptr;
}
auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(node);
mindspore::HashSet<size_t> input_to_attr = {};
auto reg_info = opt::ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(name, backend, is_dynamic_shape);
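Both the remaining common pass and the new Ascend pass resolve conversion rules from ConvertOpInfoRegister keyed by op name, backend/device, and dynamic-shape flag, and leave the node untouched when no rule is registered. Below is a minimal standalone sketch of that lookup idea; the registry, types, and names are invented for illustration and do not reflect the actual ConvertOpInfoRegister API:

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>

struct ToyConvertRule {
  std::string target_op_name;
};

class ToyRegistry {
 public:
  static ToyRegistry &GetInstance() {
    static ToyRegistry instance;
    return instance;
  }
  void Register(const std::string &op, const std::string &device, bool is_dynamic, ToyConvertRule rule) {
    rules_[std::make_tuple(op, device, is_dynamic)] = std::move(rule);
  }
  std::optional<ToyConvertRule> Find(const std::string &op, const std::string &device, bool is_dynamic) const {
    auto it = rules_.find(std::make_tuple(op, device, is_dynamic));
    if (it == rules_.end()) {
      return std::nullopt;  // no rule registered: the pass keeps the original node
    }
    return it->second;
  }

 private:
  std::map<std::tuple<std::string, std::string, bool>, ToyConvertRule> rules_;
};

int main() {
  ToyRegistry::GetInstance().Register("TensorCopySlices", "Ascend", false, {"TensorCopySlices"});
  auto rule = ToyRegistry::GetInstance().Find("TensorCopySlices", "Ascend", false);
  std::cout << (rule ? "rule found, target op: " + rule->target_op_name
                     : std::string("no rule registered, node left unchanged"))
            << std::endl;
  return 0;
}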

View File

@@ -30,8 +30,6 @@
namespace mindspore {
namespace pynative {
namespace PyNativeAlgo {
const std::set<std::string> kAxisNone = {"ReduceSum"};
std::string Common::GetIdByValue(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(v);
if (v->isa<tensor::Tensor>()) {
@@ -398,14 +396,10 @@ void DataConvert::ConvertTupleValueToTensor(const FrontendOpRunInfoPtr &op_run_i
const auto &tuple_inputs = value_seq->value();
if (tuple_inputs.empty()) {
if (kAxisNone.find(op_prim->name()) != kAxisNone.end()) {
(void)op_run_info->base_op_run_info.input_tensor.emplace_back(
std::make_shared<tensor::Tensor>(static_cast<int64_t>(0), kInt64));
(void)op_run_info->base_op_run_info.input_mask.emplace_back(kParameterDataTensorMask);
return;
} else {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
std::vector<int64_t> axis = {};
(void)op_run_info->base_op_run_info.input_tensor.emplace_back(std::make_shared<tensor::Tensor>(axis, kInt64));
(void)op_run_info->base_op_run_info.input_mask.emplace_back(kValueNodeTensorMask);
return;
}
if (tuple_inputs[0]->isa<tensor::Tensor>()) {
PlantTensorTupleToVector(op_run_info, value_seq, op_prim, index);
@@ -442,6 +436,15 @@ void DataConvert::ConvertValueToTensor(const FrontendOpRunInfoPtr &op_run_info,
}
tensor_ptr = std::make_shared<tensor::Tensor>(input, kInt64);
tensor_mask = kValueNodeTensorMask;
} else if (v->isa<Type>()) {
int64_t type_id = v->cast<TypePtr>()->type_id();
tensor_ptr = std::make_shared<tensor::Tensor>(type_id, kInt64);
tensor_mask = kValueNodeTensorMask;
} else if (v->isa<StringImm>()) {
auto value_string = GetValue<std::string>(v);
const ShapeVector shape = {1, SizeToLong(value_string.size())};
tensor_ptr = std::make_shared<tensor::Tensor>(kObjectTypeString, shape, value_string.data(), value_string.size());
tensor_mask = kValueNodeTensorMask;
} else if (v->isa<ValueSequence>()) {
ConvertTupleValueToTensor(op_run_info, v->cast<ValueSequencePtr>(), op_prim, index);
return;
@@ -472,6 +475,11 @@ bool DataConvert::NeedConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_run
return !input_to_attr_ptr->empty();
}
// The Ascend const-input-to-attr conversion has been moved to AscendConvertConstInputToAttr
if (device_target == kAscendDevice) {
return false;
}
auto reg_info = opt::ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(
op_run_info->base_op_run_info.op_name, device_target, PyNativeAlgo::Common::IsDynamicShape(op_run_info));
if (reg_info == nullptr) {

View File

@@ -130,7 +130,15 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n
switch (result->second) {
case ATTR_DTYPE::ATTR_INT32:
(*attr_obj)[kJValue] = value->isa<Int32Imm>() ? GetValue<int>(value) : GetValue<int64_t>(value);
if (value->isa<Int32Imm>()) {
(*attr_obj)[kJValue] = GetValue<int>(value);
} else if (value->isa<Int64Imm>()) {
(*attr_obj)[kJValue] = GetValue<int64_t>(value);
} else {
MS_LOG(ERROR) << "Parse int32 attr value failed. Attr value:" << value->ToString()
<< ", Type:" << value->type_name();
return false;
}
break;
case ATTR_DTYPE::ATTR_INT64:
(*attr_obj)[kJValue] = GetValue<int64_t>(value);
@@ -173,7 +181,8 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n
break;
default:
MS_LOG(ERROR) << "Type: " << type << "not support";
MS_LOG(ERROR) << "Parse attr value failed. Attr Type: " << type << "not support. Attr value:" << value->ToString()
<< ", Type:" << value->type_name();
return false;
}
return true;
@@ -475,6 +484,7 @@ void TbeJsonCreator::GenDesJsonCommon(nlohmann::json *output_desc) const {
}
void ParseConstValue(const mindspore::ValuePtr &value, nlohmann::json *json_obj) {
MS_EXCEPTION_IF_NULL(json_obj);
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);

View File

@@ -167,7 +167,7 @@
#include "plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h"
#include "plugin/device/ascend/optimizer/mindir/reg_ascend_const_input_to_attr.h"
#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"
#include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
#include "backend/common/pass/gradients_allreduce_depend_last_send.h"
#include "backend/common/pass/optimize_gradients_allreduce_overlap.h"
@@ -681,6 +681,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
unify_mindir_pm->AddPass(std::make_shared<opt::AscendConvertConstInputToAttr>());
optimizer->AddPassManager(unify_mindir_pm);
(void)optimizer->Optimize(kernel_graph);

View File

@@ -0,0 +1,333 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <map>
#include <set>
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "plugin/device/ascend/optimizer/mindir/reg_ascend_const_input_to_attr.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
namespace mindspore::opt {
const AnfNodePtr AscendConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}
auto op_name = common::AnfAlgo::GetCNodeName(node);
auto is_dynamic = common::AnfAlgo::IsDynamicShape(node);
auto convert_op_info = ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(op_name, kAscendDevice, is_dynamic);
if (convert_op_info == nullptr) {
return nullptr;
}
auto origin_op = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(origin_op);
auto ret_node = ConvertToTargetOp(origin_op, convert_op_info);
if ((ret_node != nullptr) && (ret_node != origin_op)) {
MS_LOG(INFO) << "Replace op " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
<< " with " << ret_node->fullname_with_scope() << " debug string:" << ret_node->DebugString()
<< ", is dynamic shape:" << is_dynamic;
}
return ret_node;
}
CNodePtr AscendConvertConstInputToAttr::ConvertToTargetOp(const CNodePtr &origin_op,
ConvertOpInfo *convert_op_info) const {
MS_EXCEPTION_IF_NULL(origin_op);
MS_EXCEPTION_IF_NULL(convert_op_info);
auto pre_check_func = convert_op_info->GetPreCheckFunc();
// Run the op's custom pre-check function, if one is registered
if (pre_check_func != nullptr) {
auto ret = pre_check_func(origin_op);
if (!ret) {
MS_LOG(DEBUG) << "Pre check function return Not Change for op " << origin_op->fullname_with_scope();
return origin_op;
}
}
// Check whether the op is supported by AI Core, if this op requires the check
auto graph = origin_op->func_graph();
auto kernel_graph = graph->cast<KernelGraphPtr>();
auto is_need_check = convert_op_info->GetNeedCheckFlag();
if (is_need_check) {
auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
// For a dynamic-shape cnode, keep the origin op if it is already supported
if (is_dynamic) {
auto ret = CheckAICoreSupported(origin_op);
if (ret) {
MS_LOG(DEBUG) << "Origin op " << origin_op->fullname_with_scope() << " is supported in this configuration";
return origin_op;
}
}
auto target_op = CreateTargetOp(origin_op, convert_op_info);
if (target_op == nullptr) {
MS_LOG(DEBUG) << "Create target op failed for node " << origin_op->fullname_with_scope();
return origin_op;
}
auto ret = CheckAICoreSupported(target_op);
if (!ret) {
return origin_op;
}
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(origin_op, target_op);
}
return target_op;
} else {
auto target_op = CreateTargetOp(origin_op, convert_op_info);
if (target_op == nullptr) {
MS_LOG(DEBUG) << "Create target op failed for node " << origin_op->fullname_with_scope();
return origin_op;
}
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(origin_op, target_op);
}
return target_op;
}
}
template <typename T, typename Scalar>
ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
ValuePtr ret;
auto tensor_value = TensorValueToVector<T>(tensor);
if (tensor_value.size() == 1) {
ret = std::make_shared<Scalar>(tensor_value[0]);
} else {
std::vector<ValuePtr> value_vec;
for (const auto &elem : tensor_value) {
auto value = std::make_shared<Scalar>(elem);
MS_EXCEPTION_IF_NULL(value);
value_vec.push_back(value);
}
ret = std::make_shared<ValueTuple>(value_vec);
}
return ret;
}
ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
TypePtr data_type = tensor->Dtype();
MS_EXCEPTION_IF_NULL(data_type);
TypeId type_id = data_type->type_id();
ValuePtr ret;
switch (type_id) {
case kNumberTypeInt8: {
ret = GetTensorValue<int8_t, Int8Imm>(tensor);
break;
}
case kNumberTypeUInt8: {
ret = GetTensorValue<uint8_t, UInt8Imm>(tensor);
break;
}
case kNumberTypeInt16: {
ret = GetTensorValue<int16_t, Int16Imm>(tensor);
break;
}
case kNumberTypeUInt16: {
ret = GetTensorValue<uint16_t, UInt16Imm>(tensor);
break;
}
case kNumberTypeInt32: {
ret = GetTensorValue<int32_t, Int32Imm>(tensor);
break;
}
case kNumberTypeUInt32: {
ret = GetTensorValue<uint32_t, UInt32Imm>(tensor);
break;
}
case kNumberTypeInt64: {
ret = GetTensorValue<int64_t, Int64Imm>(tensor);
break;
}
case kNumberTypeUInt64: {
ret = GetTensorValue<uint64_t, UInt64Imm>(tensor);
break;
}
case kNumberTypeFloat32: {
ret = GetTensorValue<float, FP32Imm>(tensor);
break;
}
case kNumberTypeFloat64: {
ret = GetTensorValue<double, FP64Imm>(tensor);
break;
}
default:
MS_LOG(EXCEPTION) << "Can't parse attr value :" << tensor->ToString() << ", Type:" << tensor->type_name();
}
return ret;
}
CNodePtr AscendConvertConstInputToAttr::CreateTargetOp(const CNodePtr &origin_op,
ConvertOpInfo *convert_op_info) const {
MS_EXCEPTION_IF_NULL(origin_op);
MS_EXCEPTION_IF_NULL(convert_op_info);
auto target_op_name = convert_op_info->GetTargetOpName();
auto input_attr_info_map = convert_op_info->GetInputAttrInfoMap();
auto origin_primitive = GetCNodePrimitive(origin_op);
MS_EXCEPTION_IF_NULL(origin_primitive);
auto target_primitive = std::make_shared<Primitive>(target_op_name);
MS_EXCEPTION_IF_NULL(target_primitive);
target_primitive->SetAttrs(origin_primitive->attrs());
std::vector<AnfNodePtr> target_inputs;
auto inputs = origin_op->inputs();
target_inputs.push_back(inputs[0]);
auto input_names = origin_primitive->GetAttr(kAttrInputNames);
if (input_names == nullptr) {
MS_LOG(WARNING) << "input_names are nullptr in cnode[" << origin_op->DebugString() << "]";
return nullptr;
}
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
for (size_t i = 0; i < inputs.size() - 1; ++i) {
auto input_node = inputs[i + 1];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
input_node = AnfUtils::VisitKernel(input_node, 0).first;
}
auto iter = input_attr_info_map.find(i);
if (iter != input_attr_info_map.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
auto ret = ConvertInputToAttr(origin_op, target_op_name, input_names_vec, i, input_node, iter, target_primitive);
if (!ret) {
return nullptr;
}
} else {
target_inputs.push_back(inputs[i + 1]);
}
}
// Update target_op's inputs
target_inputs[0] = NewValueNode(target_primitive);
auto graph = origin_op->func_graph();
MS_EXCEPTION_IF_NULL(graph);
auto target_op = opt::NewCNode(target_inputs, graph, {origin_op});
MS_EXCEPTION_IF_NULL(target_op);
target_op->set_abstract(origin_op->abstract());
target_op->set_scope(origin_op->scope());
target_op->set_primal_attrs(origin_op->primal_attrs());
target_op->set_attrs(origin_op->attrs());
auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << " debug string:" << target_op->DebugString()
<< " from " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
<< ", is dynamic shape:" << is_dynamic;
return target_op;
}
bool AscendConvertConstInputToAttr::ConvertInputToAttr(const CNodePtr &origin_op, const string &target_op_name,
const std::vector<std::string> &input_names_vec, size_t i,
const std::shared_ptr<AnfNode> &input_node,
const std::map<size_t, InputAttrInfo>::iterator &iter,
const std::shared_ptr<Primitive> &target_primitive) const {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
MS_LOG(DEBUG) << "start erase input[" << i
<< "] of cnode[" + origin_op->DebugString() + "], origin value:" << value_node->ToString()
<< ", Type:" << value_node->type_name();
if (i >= input_names_vec.size()) {
MS_LOG(WARNING) << "Input index is invalid. input index: " << i << ", input name size " << input_names_vec.size();
return false;
}
auto value = value_node->value();
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
if (tensor->data().const_data() == nullptr) {
MS_LOG(DEBUG) << "Const input data ptr is null from op " << origin_op->fullname_with_scope() << "'s input " << i;
return false;
}
value = CreateValueFromTensor(tensor);
}
auto attr_name = GetAttrName(target_op_name, iter, input_names_vec[i]);
value = UpdateAttrValue(origin_op, iter, value, attr_name);
MS_LOG(DEBUG) << "new attr value:" << value_node->ToString() << ", Type:" << value_node->type_name();
target_primitive->set_attr(attr_name, value);
return true;
}
std::string AscendConvertConstInputToAttr::GetAttrName(const string &target_op_name,
const std::map<size_t, InputAttrInfo>::iterator &iter,
const string &input_name) const {
auto attr_name = iter->second.GetAttrName();
if (attr_name.empty()) {
MS_LOG(INFO) << "Attr name is empty for op " << target_op_name << ", use input name " << input_name << " instead.";
attr_name = input_name;
} else if (attr_name != input_name) {
MS_LOG(WARNING) << "Attr name not match input name: " << attr_name << " vs " << input_name;
}
return attr_name;
}
ValuePtr AscendConvertConstInputToAttr::UpdateAttrValue(const CNodePtr &origin_op,
const std::map<size_t, InputAttrInfo>::iterator &iter,
const ValuePtr &value, const string &attr_name) const {
ValuePtr ret = value;
auto attr_dtype = iter->second.GetAttrDataType();
if (attr_dtype.empty()) {
// TODO(laiyongqiang): throw an exception here once the refactoring is done, since reaching this branch means the attr info is wrong.
auto op_name = common::AnfAlgo::GetCNodeName(origin_op);
auto op_info_ptr = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, origin_op);
if (op_info_ptr) {
auto op_info_attrs_ptr = op_info_ptr->attrs_ptr();
for (const auto &op_info_attr_ptr : op_info_attrs_ptr) {
std::string op_attr_name = op_info_attr_ptr->name();
if (op_attr_name == attr_name) {
attr_dtype = op_info_attr_ptr->type();
break;
}
}
}
}
if (!attr_dtype.empty()) {
ret = UpdateAttrValueByDtype(value, attr_dtype);
}
return ret;
}
ValuePtr AscendConvertConstInputToAttr::UpdateAttrValueByDtype(const ValuePtr &value,
const std::string &attr_data_type) const {
static std::set<std::string> kListDataType = {"listInt", "listStr", "listBool", "listFloat"};
auto iter = kListDataType.find(attr_data_type);
ValuePtr ret = value;
if (iter != kListDataType.end()) {
if (!value->isa<ValueSequence>()) {
std::vector<ValuePtr> value_vec;
value_vec.push_back(value);
ret = std::make_shared<ValueTuple>(value_vec);
}
}
return ret;
}
} // namespace mindspore::opt

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_
#include <string>
#include <memory>
#include <vector>
#include <map>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/const_input_to_attr_factory.h"
namespace mindspore {
namespace opt {
class AscendConvertConstInputToAttr : public PatternProcessPass {
public:
explicit AscendConvertConstInputToAttr(bool multigraph = true)
: PatternProcessPass("ascend_convert_const_input_to_attr", multigraph) {}
~AscendConvertConstInputToAttr() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
private:
CNodePtr ConvertToTargetOp(const CNodePtr &origin_op, mindspore::opt::ConvertOpInfo *convert_op_info) const;
CNodePtr CreateTargetOp(const CNodePtr &origin_op, ConvertOpInfo *convert_op_info) const;
bool ConvertInputToAttr(const CNodePtr &origin_op, const string &target_op_name,
const std::vector<std::string> &input_names_vec, size_t i,
const std::shared_ptr<AnfNode> &input_node,
const std::map<size_t, InputAttrInfo>::iterator &iter,
const std::shared_ptr<Primitive> &target_primitive) const;
string GetAttrName(const string &target_op_name, const std::map<size_t, InputAttrInfo>::iterator &iter,
const string &input_name) const;
ValuePtr UpdateAttrValueByDtype(const ValuePtr &value, const string &attr_data_type) const;
ValuePtr UpdateAttrValue(const CNodePtr &origin_op, const std::map<size_t, InputAttrInfo>::iterator &iter,
const ValuePtr &value, const string &attr_name) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_

View File

@@ -90,13 +90,23 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kSplitOpName, 0);
RER_ASCEND_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
RER_ASCEND_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
RER_ASCEND_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
RER_ASCEND_STATIC_CONST_TO_ATTR(kTileOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentMaxOpName, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentMinOpName, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentProdOpName, 2);
RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentSumOpName, 2);
// =============================== new reg interface =================================================
#define REG_ASCEND_CONST_TO_ATTR(origin_op_name, target_op_name) \
REG_CONST_TO_ATTR(origin_op_name, target_op_name, kAscendDevice, false)
// RTS OP
REG_ASCEND_CONST_TO_ATTR(kTensorCopySlicesOpName, kTensorCopySlicesOpName)
.SetNeedCheckSupported(false)
.SetInputAttrInfo(2, "begin", "listInt")
.SetInputAttrInfo(3, "end", "listInt")
.SetInputAttrInfo(4, "strides", "listInt");
} // namespace mindspore::opt
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_MINDIR_REG_ASCEND_CONST_INPUT_TO_ATTR_H_
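For context, the new registration interface records, per op, which constant input indices become which attributes and with what dtype string, through chained SetNeedCheckSupported/SetInputAttrInfo calls. The standalone sketch below models that builder pattern with toy types (it is not the real ConvertOpInfo/RegisterHelper implementation) and replays the TensorCopySlices registration above:

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct ToyInputAttrInfo {
  std::string attr_name;
  std::string attr_dtype;
};

class ToyConvertOpInfo {
 public:
  ToyConvertOpInfo(std::string origin, std::string target)
      : origin_(std::move(origin)), target_(std::move(target)) {}
  ToyConvertOpInfo &SetNeedCheckSupported(bool flag) {
    need_check_supported_ = flag;
    return *this;
  }
  ToyConvertOpInfo &SetInputAttrInfo(size_t index, std::string name, std::string dtype) {
    input_attr_map_[index] = {std::move(name), std::move(dtype)};
    return *this;
  }
  void Dump() const {
    std::cout << origin_ << " -> " << target_ << " (need check supported: " << need_check_supported_ << ")\n";
    for (const auto &item : input_attr_map_) {
      std::cout << "  input " << item.first << " => attr '" << item.second.attr_name << "' ("
                << item.second.attr_dtype << ")\n";
    }
  }

 private:
  std::string origin_;
  std::string target_;
  bool need_check_supported_ = true;
  std::map<size_t, ToyInputAttrInfo> input_attr_map_;
};

int main() {
  // Replays the TensorCopySlices registration shown in the diff above.
  ToyConvertOpInfo("TensorCopySlices", "TensorCopySlices")
      .SetNeedCheckSupported(false)
      .SetInputAttrInfo(2, "begin", "listInt")
      .SetInputAttrInfo(3, "end", "listInt")
      .SetInputAttrInfo(4, "strides", "listInt")
      .Dump();
  return 0;
}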

View File

@@ -651,7 +651,7 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
}
CNodePtr softmax_node;
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this, true);
auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this);
softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);
std::vector<AnfNodePtr> softmax_node_outputs;

View File

@@ -1236,9 +1236,6 @@ bool AnfAlgo::IsKernelDynamicImpl(const AnfNodePtr &node) {
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!HasNodeAttr(kAttrIsKernelDynamicImpl, cnode)) {
return false;
}
return GetBooleanAttr(node, kAttrIsKernelDynamicImpl);
}

View File

@@ -16,6 +16,7 @@
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/common/optimizer/optimizer.h"
#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"
#include "plugin/device/ascend/optimizer/ir_fusion/confusion_softmax_grad_rule.h"
#include "include/common/debug/anf_ir_dump.h"
@@ -43,6 +44,7 @@ TEST_F(TestHWOptimizeConfusionSoftmaxGradRule, test_confusion_softmax_grad_rule)
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AscendConvertConstInputToAttr>());
pm->AddPass(std::make_shared<opt::ConfusionSoftmaxGradRule>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);