!40668 move ascend const input to attr pass to unify mindir
Merge pull request !40668 from laiyongqiang/new_pass
Commit: ccf3db786c
@@ -20,6 +20,7 @@
 "mindspore/mindspore/ccsrc/runtime/hardware/device_context.h" "readability/braces"
 "mindspore/mindspore/ccsrc/transform/graph_ir/convert.h" "runtime/references"
 "mindspore/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/gather_grad_kernels.cc" "build/include"
+"mindspore/mindspore/ccsrc/backend/common/optimizer/const_input_to_attr_factory.h" "runtime/explicit"

 # Modelzoo
 "mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
@@ -129,7 +129,7 @@ class BACKEND_EXPORT ConvertOpInfoRegister {
 class RegisterHelper {
  public:
   RegisterHelper(const std::string &name, const std::string &device_name, bool is_dynamic_shape, int len, ...);
-  RegisterHelper(const ConvertOpInfo &convert_op_info);
+  explicit RegisterHelper(const ConvertOpInfo &convert_op_info);
   ~RegisterHelper() = default;

  private:
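The single-argument RegisterHelper constructor is now explicit, so a ConvertOpInfo can no longer convert to a RegisterHelper silently. A minimal standalone sketch of what that prevents, using simplified stand-in types rather than the real MindSpore classes:

#include <string>

struct ConvertOpInfo {
  std::string op_name;
};

class RegisterHelper {
 public:
  // A single-argument constructor is an implicit conversion point unless
  // marked explicit; cpplint's runtime/explicit check flags the non-explicit form.
  explicit RegisterHelper(const ConvertOpInfo &info) : info_(info) {}

 private:
  ConvertOpInfo info_;
};

void Consume(const RegisterHelper &helper) { (void)helper; }

int main() {
  ConvertOpInfo info{"TensorCopySlices"};
  Consume(RegisterHelper{info});  // OK: the conversion is spelled out.
  // Consume(info);               // Would compile without 'explicit'; now rejected.
  return 0;
}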
@@ -42,6 +42,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
                   << ", primitive_target: " << primitive_target;
     backend = primitive_target;
   }
+
+  // Ascend const input to attr is moved to AscendConvertConstInputToAttr
+  if (backend == kAscendDevice) {
+    return nullptr;
+  }
+
   auto is_dynamic_shape = common::AnfAlgo::IsDynamicShape(node);
   mindspore::HashSet<size_t> input_to_attr = {};
   auto reg_info = opt::ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(name, backend, is_dynamic_shape);
@@ -30,8 +30,6 @@
 namespace mindspore {
 namespace pynative {
 namespace PyNativeAlgo {
-const std::set<std::string> kAxisNone = {"ReduceSum"};
-
 std::string Common::GetIdByValue(const ValuePtr &v) {
   MS_EXCEPTION_IF_NULL(v);
   if (v->isa<tensor::Tensor>()) {

@@ -398,14 +396,10 @@ void DataConvert::ConvertTupleValueToTensor(const FrontendOpRunInfoPtr &op_run_i

   const auto &tuple_inputs = value_seq->value();
   if (tuple_inputs.empty()) {
-    if (kAxisNone.find(op_prim->name()) != kAxisNone.end()) {
-      (void)op_run_info->base_op_run_info.input_tensor.emplace_back(
-        std::make_shared<tensor::Tensor>(static_cast<int64_t>(0), kInt64));
-      (void)op_run_info->base_op_run_info.input_mask.emplace_back(kParameterDataTensorMask);
-      return;
-    } else {
-      MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
-    }
+    std::vector<int64_t> axis = {};
+    (void)op_run_info->base_op_run_info.input_tensor.emplace_back(std::make_shared<tensor::Tensor>(axis, kInt64));
+    (void)op_run_info->base_op_run_info.input_mask.emplace_back(kValueNodeTensorMask);
+    return;
   }
   if (tuple_inputs[0]->isa<tensor::Tensor>()) {
     PlantTensorTupleToVector(op_run_info, value_seq, op_prim, index);

@@ -442,6 +436,15 @@ void DataConvert::ConvertValueToTensor(const FrontendOpRunInfoPtr &op_run_info,
     }
     tensor_ptr = std::make_shared<tensor::Tensor>(input, kInt64);
     tensor_mask = kValueNodeTensorMask;
+  } else if (v->isa<Type>()) {
+    int64_t type_id = v->cast<TypePtr>()->type_id();
+    tensor_ptr = std::make_shared<tensor::Tensor>(type_id, kInt64);
+    tensor_mask = kValueNodeTensorMask;
+  } else if (v->isa<StringImm>()) {
+    auto value_string = GetValue<std::string>(v);
+    const ShapeVector shape = {1, SizeToLong(value_string.size())};
+    tensor_ptr = std::make_shared<tensor::Tensor>(kObjectTypeString, shape, value_string.data(), value_string.size());
+    tensor_mask = kValueNodeTensorMask;
   } else if (v->isa<ValueSequence>()) {
     ConvertTupleValueToTensor(op_run_info, v->cast<ValueSequencePtr>(), op_prim, index);
     return;

@@ -472,6 +475,11 @@ bool DataConvert::NeedConvertConstInputToAttr(const FrontendOpRunInfoPtr &op_run
     return !input_to_attr_ptr->empty();
   }

+  // Ascend const input to attr is moved to AscendConvertConstInputToAttr
+  if (device_target == kAscendDevice) {
+    return false;
+  }
+
   auto reg_info = opt::ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(
     op_run_info->base_op_run_info.op_name, device_target, PyNativeAlgo::Common::IsDynamicShape(op_run_info));
   if (reg_info == nullptr) {
@@ -130,7 +130,15 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n

   switch (result->second) {
     case ATTR_DTYPE::ATTR_INT32:
-      (*attr_obj)[kJValue] = value->isa<Int32Imm>() ? GetValue<int>(value) : GetValue<int64_t>(value);
+      if (value->isa<Int32Imm>()) {
+        (*attr_obj)[kJValue] = GetValue<int>(value);
+      } else if (value->isa<Int64Imm>()) {
+        (*attr_obj)[kJValue] = GetValue<int64_t>(value);
+      } else {
+        MS_LOG(ERROR) << "Parse int32 attr value failed. Attr value:" << value->ToString()
+                      << ", Type:" << value->type_name();
+        return false;
+      }
       break;
     case ATTR_DTYPE::ATTR_INT64:
       (*attr_obj)[kJValue] = GetValue<int64_t>(value);

@@ -173,7 +181,8 @@ bool ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value, n
       break;

     default:
-      MS_LOG(ERROR) << "Type: " << type << "not support";
+      MS_LOG(ERROR) << "Parse attr value failed. Attr Type: " << type << " is not supported. Attr value:" << value->ToString()
+                    << ", Type:" << value->type_name();
       return false;
   }
   return true;

@@ -475,6 +484,7 @@ void TbeJsonCreator::GenDesJsonCommon(nlohmann::json *output_desc) const {
 }

 void ParseConstValue(const mindspore::ValuePtr &value, nlohmann::json *json_obj) {
+  MS_EXCEPTION_IF_NULL(json_obj);
   if (value->isa<tensor::Tensor>()) {
     auto tensor = value->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
@@ -167,7 +167,7 @@
 #include "plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h"
 #include "plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h"
 #include "plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h"
-#include "plugin/device/ascend/optimizer/mindir/reg_ascend_const_input_to_attr.h"
+#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"
 #include "backend/common/pass/adjust_depend_for_parallel_optimizer_recompute_all_gather.h"
 #include "backend/common/pass/gradients_allreduce_depend_last_send.h"
 #include "backend/common/pass/optimize_gradients_allreduce_overlap.h"

@@ -681,6 +681,7 @@ void AscendUnifyMindIR(const std::shared_ptr<session::KernelGraph> &kernel_graph
   unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2UnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::NeighborExchangeV2GradUnifyMindIR>());
   unify_mindir_pm->AddPass(std::make_shared<opt::AllToAllUnifyMindIR>());
+  unify_mindir_pm->AddPass(std::make_shared<opt::AscendConvertConstInputToAttr>());

   optimizer->AddPassManager(unify_mindir_pm);
   (void)optimizer->Optimize(kernel_graph);
@@ -0,0 +1,333 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"

#include <algorithm>
#include <memory>
#include <vector>
#include <map>
#include <set>
#include "include/common/utils/utils.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/ascend/optimizer/ascend_helper.h"
#include "plugin/device/ascend/optimizer/mindir/reg_ascend_const_input_to_attr.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"

namespace mindspore::opt {
const AnfNodePtr AscendConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
                                                        const EquivPtr &) const {
  if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  auto op_name = common::AnfAlgo::GetCNodeName(node);
  auto is_dynamic = common::AnfAlgo::IsDynamicShape(node);
  auto convert_op_info = ConvertOpInfoRegister::GetInstance().GetConvertOpInfo(op_name, kAscendDevice, is_dynamic);
  if (convert_op_info == nullptr) {
    return nullptr;
  }

  auto origin_op = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_op);
  auto ret_node = ConvertToTargetOp(origin_op, convert_op_info);
  if ((ret_node != nullptr) && (ret_node != origin_op)) {
    MS_LOG(INFO) << "Replace op " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
                 << " with " << ret_node->fullname_with_scope() << " debug string:" << ret_node->DebugString()
                 << ", is dynamic shape:" << is_dynamic;
  }
  return ret_node;
}

CNodePtr AscendConvertConstInputToAttr::ConvertToTargetOp(const CNodePtr &origin_op,
                                                          ConvertOpInfo *convert_op_info) const {
  MS_EXCEPTION_IF_NULL(origin_op);
  MS_EXCEPTION_IF_NULL(convert_op_info);
  auto pre_check_func = convert_op_info->GetPreCheckFunc();
  // Check through the op's custom pre-check function.
  if (pre_check_func != nullptr) {
    auto ret = pre_check_func(origin_op);
    if (!ret) {
      MS_LOG(DEBUG) << "Pre check function return Not Change for op " << origin_op->fullname_with_scope();
      return origin_op;
    }
  }

  // Check whether the converted op is supported, if the op requires the check.
  auto graph = origin_op->func_graph();
  auto kernel_graph = graph->cast<KernelGraphPtr>();
  auto is_need_check = convert_op_info->GetNeedCheckFlag();
  if (is_need_check) {
    auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
    // When the cnode is a dynamic shape node, keep the origin op if it is supported.
    if (is_dynamic) {
      auto ret = CheckAICoreSupported(origin_op);
      if (ret) {
        MS_LOG(DEBUG) << "Origin op " << origin_op->fullname_with_scope() << " is supported in this configuration";
        return origin_op;
      }
    }

    auto target_op = CreateTargetOp(origin_op, convert_op_info);
    if (target_op == nullptr) {
      MS_LOG(DEBUG) << "Create target op failed for node " << origin_op->fullname_with_scope();
      return origin_op;
    }

    auto ret = CheckAICoreSupported(target_op);
    if (!ret) {
      return origin_op;
    }

    if (kernel_graph != nullptr) {
      kernel_graph->FrontBackendlMapUpdate(origin_op, target_op);
    }
    return target_op;
  } else {
    auto target_op = CreateTargetOp(origin_op, convert_op_info);
    if (target_op == nullptr) {
      MS_LOG(DEBUG) << "Create target op failed for node " << origin_op->fullname_with_scope();
      return origin_op;
    }
    if (kernel_graph != nullptr) {
      kernel_graph->FrontBackendlMapUpdate(origin_op, target_op);
    }
    return target_op;
  }
}

template <typename T, typename Scalar>
ValuePtr GetTensorValue(const tensor::TensorPtr &tensor) {
  ValuePtr ret;
  auto tensor_value = TensorValueToVector<T>(tensor);
  if (tensor_value.size() == 1) {
    ret = std::make_shared<Scalar>(tensor_value[0]);
  } else {
    std::vector<ValuePtr> value_vec;
    for (const auto &elem : tensor_value) {
      auto value = std::make_shared<Scalar>(elem);
      MS_EXCEPTION_IF_NULL(value);
      value_vec.push_back(value);
    }
    ret = std::make_shared<ValueTuple>(value_vec);
  }
  return ret;
}

ValuePtr CreateValueFromTensor(const tensor::TensorPtr &tensor) {
  TypePtr data_type = tensor->Dtype();
  MS_EXCEPTION_IF_NULL(data_type);
  TypeId type_id = data_type->type_id();
  ValuePtr ret;
  switch (type_id) {
    case kNumberTypeInt8: {
      ret = GetTensorValue<int8_t, Int8Imm>(tensor);
      break;
    }

    case kNumberTypeUInt8: {
      ret = GetTensorValue<uint8_t, UInt8Imm>(tensor);
      break;
    }

    case kNumberTypeInt16: {
      ret = GetTensorValue<int16_t, Int16Imm>(tensor);
      break;
    }

    case kNumberTypeUInt16: {
      ret = GetTensorValue<uint16_t, UInt16Imm>(tensor);
      break;
    }

    case kNumberTypeInt32: {
      ret = GetTensorValue<int32_t, Int32Imm>(tensor);
      break;
    }

    case kNumberTypeUInt32: {
      ret = GetTensorValue<uint32_t, UInt32Imm>(tensor);
      break;
    }

    case kNumberTypeInt64: {
      ret = GetTensorValue<int64_t, Int64Imm>(tensor);
      break;
    }

    case kNumberTypeUInt64: {
      ret = GetTensorValue<uint64_t, UInt64Imm>(tensor);
      break;
    }

    case kNumberTypeFloat32: {
      ret = GetTensorValue<float, FP32Imm>(tensor);
      break;
    }

    case kNumberTypeFloat64: {
      ret = GetTensorValue<double, FP64Imm>(tensor);
      break;
    }

    default:
      MS_LOG(EXCEPTION) << "Can't parse attr value :" << tensor->ToString() << ", Type:" << tensor->type_name();
  }
  return ret;
}

CNodePtr AscendConvertConstInputToAttr::CreateTargetOp(const CNodePtr &origin_op,
                                                       ConvertOpInfo *convert_op_info) const {
  MS_EXCEPTION_IF_NULL(origin_op);
  MS_EXCEPTION_IF_NULL(convert_op_info);
  auto target_op_name = convert_op_info->GetTargetOpName();
  auto input_attr_info_map = convert_op_info->GetInputAttrInfoMap();

  auto origin_primitive = GetCNodePrimitive(origin_op);
  MS_EXCEPTION_IF_NULL(origin_primitive);
  auto target_primitive = std::make_shared<Primitive>(target_op_name);
  MS_EXCEPTION_IF_NULL(target_primitive);
  target_primitive->SetAttrs(origin_primitive->attrs());
  std::vector<AnfNodePtr> target_inputs;
  auto inputs = origin_op->inputs();
  target_inputs.push_back(inputs[0]);

  auto input_names = origin_primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(WARNING) << "input_names are nullptr in cnode[" << origin_op->DebugString() << "]";
    return nullptr;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);

  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
      input_node = AnfUtils::VisitKernel(input_node, 0).first;
    }

    auto iter = input_attr_info_map.find(i);
    if (iter != input_attr_info_map.end() && input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
      auto ret = ConvertInputToAttr(origin_op, target_op_name, input_names_vec, i, input_node, iter, target_primitive);
      if (!ret) {
        return nullptr;
      }
    } else {
      target_inputs.push_back(inputs[i + 1]);
    }
  }

  // Update target_op's inputs
  target_inputs[0] = NewValueNode(target_primitive);
  auto graph = origin_op->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  auto target_op = opt::NewCNode(target_inputs, graph, {origin_op});
  MS_EXCEPTION_IF_NULL(target_op);
  target_op->set_abstract(origin_op->abstract());
  target_op->set_scope(origin_op->scope());
  target_op->set_primal_attrs(origin_op->primal_attrs());
  target_op->set_attrs(origin_op->attrs());
  auto is_dynamic = common::AnfAlgo::IsDynamicShape(origin_op);
  MS_LOG(DEBUG) << "Create op " << target_op->fullname_with_scope() << " debug string:" << target_op->DebugString()
                << " from " << origin_op->fullname_with_scope() << " debug string:" << origin_op->DebugString()
                << ", is dynamic shape:" << is_dynamic;
  return target_op;
}

bool AscendConvertConstInputToAttr::ConvertInputToAttr(const CNodePtr &origin_op, const string &target_op_name,
                                                       const std::vector<std::string> &input_names_vec, size_t i,
                                                       const std::shared_ptr<AnfNode> &input_node,
                                                       const std::map<size_t, InputAttrInfo>::iterator &iter,
                                                       const std::shared_ptr<Primitive> &target_primitive) const {
  auto value_node = input_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  MS_LOG(DEBUG) << "start erase input[" << i
                << "] of cnode[" + origin_op->DebugString() + "], origin value:" << value_node->ToString()
                << ", Type:" << value_node->type_name();
  if (i >= input_names_vec.size()) {
    MS_LOG(WARNING) << "Input index is invalid. input index: " << i << ", input name size " << input_names_vec.size();
    return false;
  }

  auto value = value_node->value();
  if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    if (tensor->data().const_data() == nullptr) {
      MS_LOG(DEBUG) << "Const input data ptr is null from op " << origin_op->fullname_with_scope() << "'s input " << i;
      return false;
    }
    value = CreateValueFromTensor(tensor);
  }

  auto attr_name = GetAttrName(target_op_name, iter, input_names_vec[i]);
  value = UpdateAttrValue(origin_op, iter, value, attr_name);
  MS_LOG(DEBUG) << "new attr value:" << value_node->ToString() << ", Type:" << value_node->type_name();
  target_primitive->set_attr(attr_name, value);
  return true;
}

std::string AscendConvertConstInputToAttr::GetAttrName(const string &target_op_name,
                                                       const std::map<size_t, InputAttrInfo>::iterator &iter,
                                                       const string &input_name) const {
  auto attr_name = iter->second.GetAttrName();
  if (attr_name.empty()) {
    MS_LOG(INFO) << "Attr name is empty for op " << target_op_name << ", use input name " << input_name << " instead.";
    attr_name = input_name;
  } else if (attr_name != input_name) {
    MS_LOG(WARNING) << "Attr name not match input name: " << attr_name << " vs " << input_name;
  }
  return attr_name;
}

ValuePtr AscendConvertConstInputToAttr::UpdateAttrValue(const CNodePtr &origin_op,
                                                        const std::map<size_t, InputAttrInfo>::iterator &iter,
                                                        const ValuePtr &value, const string &attr_name) const {
  ValuePtr ret = value;
  auto attr_dtype = iter->second.GetAttrDataType();
  if (attr_dtype.empty()) {
    // TODO(laiyongqiang): raise an exception here once the refactoring is done; an empty dtype means the attr info is wrong.
    auto op_name = common::AnfAlgo::GetCNodeName(origin_op);
    auto op_info_ptr = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, origin_op);
    if (op_info_ptr) {
      auto op_info_attrs_ptr = op_info_ptr->attrs_ptr();
      for (const auto &op_info_attr_ptr : op_info_attrs_ptr) {
        std::string op_attr_name = op_info_attr_ptr->name();
        if (op_attr_name == attr_name) {
          attr_dtype = op_info_attr_ptr->type();
          break;
        }
      }
    }
  }
  if (!attr_dtype.empty()) {
    ret = UpdateAttrValueByDtype(value, attr_dtype);
  }
  return ret;
}

ValuePtr AscendConvertConstInputToAttr::UpdateAttrValueByDtype(const ValuePtr &value,
                                                               const std::string &attr_data_type) const {
  static std::set<std::string> kListDataType = {"listInt", "listStr", "listBool", "listFloat"};
  auto iter = kListDataType.find(attr_data_type);
  ValuePtr ret = value;
  if (iter != kListDataType.end()) {
    if (!value->isa<ValueSequence>()) {
      std::vector<ValuePtr> value_vec;
      value_vec.push_back(value);
      ret = std::make_shared<ValueTuple>(value_vec);
    }
  }
  return ret;
}
}  // namespace mindspore::opt
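CreateValueFromTensor above dispatches on the tensor's dtype and, via GetTensorValue, collapses a one-element tensor into a single scalar value while turning anything longer into a tuple of scalars. A rough standalone analogue of that decision, with std::variant standing in for MindSpore's ValuePtr hierarchy (the names below are illustrative only, not the real API):

#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

// Stand-ins for Int64Imm / ValueTuple in this sketch.
using ScalarValue = int64_t;
using TupleValue = std::vector<int64_t>;
using Value = std::variant<ScalarValue, TupleValue>;

// Mirrors the GetTensorValue rule: one element becomes a scalar, otherwise a tuple.
Value FromTensorData(const std::vector<int64_t> &data) {
  if (data.size() == 1) {
    return ScalarValue{data[0]};
  }
  return TupleValue(data);
}

int main() {
  Value axis = FromTensorData({2});         // holds a scalar 2
  Value begin = FromTensorData({0, 1, 0});  // holds a tuple {0, 1, 0}
  std::cout << std::holds_alternative<ScalarValue>(axis) << " "
            << std::holds_alternative<TupleValue>(begin) << "\n";  // prints "1 1"
  return 0;
}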
@@ -0,0 +1,56 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_
#include <string>
#include <memory>
#include <vector>
#include <map>

#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/const_input_to_attr_factory.h"

namespace mindspore {
namespace opt {
class AscendConvertConstInputToAttr : public PatternProcessPass {
 public:
  explicit AscendConvertConstInputToAttr(bool multigraph = true)
      : PatternProcessPass("ascend_convert_const_input_to_attr", multigraph) {}
  ~AscendConvertConstInputToAttr() override = default;

  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;

 private:
  CNodePtr ConvertToTargetOp(const CNodePtr &origin_op, mindspore::opt::ConvertOpInfo *convert_op_info) const;
  CNodePtr CreateTargetOp(const CNodePtr &origin_op, ConvertOpInfo *convert_op_info) const;
  bool ConvertInputToAttr(const CNodePtr &origin_op, const string &target_op_name,
                          const std::vector<std::string> &input_names_vec, size_t i,
                          const std::shared_ptr<AnfNode> &input_node,
                          const std::map<size_t, InputAttrInfo>::iterator &iter,
                          const std::shared_ptr<Primitive> &target_primitive) const;
  string GetAttrName(const string &target_op_name, const std::map<size_t, InputAttrInfo>::iterator &iter,
                     const string &input_name) const;
  ValuePtr UpdateAttrValueByDtype(const ValuePtr &value, const string &attr_data_type) const;
  ValuePtr UpdateAttrValue(const CNodePtr &origin_op, const std::map<size_t, InputAttrInfo>::iterator &iter,
                           const ValuePtr &value, const string &attr_name) const;
};
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ASCEND_CONVERT_CONST_INPUT_TO_ATTR_H_
@@ -90,13 +90,23 @@ RER_ASCEND_STATIC_CONST_TO_ATTR(kSplitOpName, 0);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kSubscalarOpName, 1);
-RER_ASCEND_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kTileOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentMaxOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentMinOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentProdOpName, 2);
 RER_ASCEND_STATIC_CONST_TO_ATTR(kUnsortedSegmentSumOpName, 2);
+
+// =============================== new reg interface =================================================
+#define REG_ASCEND_CONST_TO_ATTR(origin_op_name, target_op_name) \
+  REG_CONST_TO_ATTR(origin_op_name, target_op_name, kAscendDevice, false)
+
+// RTS OP
+REG_ASCEND_CONST_TO_ATTR(kTensorCopySlicesOpName, kTensorCopySlicesOpName)
+  .SetNeedCheckSupported(false)
+  .SetInputAttrInfo(2, "begin", "listInt")
+  .SetInputAttrInfo(3, "end", "listInt")
+  .SetInputAttrInfo(4, "strides", "listInt");
 }  // namespace mindspore::opt

 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_OPTIMIZER_MINDIR_REG_ASCEND_CONST_INPUT_TO_ATTR_H_
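The new registration interface above replaces the per-index RER_ASCEND_STATIC_CONST_TO_ATTR entry for TensorCopySlices with a builder-style registration that names the target op, whether kernel support must be re-checked, and the attribute name and dtype for each converted input. As a purely hypothetical sketch of registering another op through this interface (the op-name constant and attribute below are made up for illustration and are not part of this patch):

// Hypothetical registration following the TensorCopySlices pattern above.
// kMyReduceOpName and the "axis" attribute are assumed for illustration only.
REG_ASCEND_CONST_TO_ATTR(kMyReduceOpName, kMyReduceOpName)
  .SetNeedCheckSupported(true)              // ask the pass to verify AICore support for the converted op
  .SetInputAttrInfo(1, "axis", "listInt");  // fold input #1 into a list-of-int "axis" attribute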
@@ -651,7 +651,7 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
   }

   CNodePtr softmax_node;
-  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this, true);
+  auto one_hot_node = CreateOneHot(graph, sparse_softmax_node, *this);
   softmax_node = CreateSoftmaxCrossEntropyWithLogits(graph, sparse_softmax_node, one_hot_node, *this);

   std::vector<AnfNodePtr> softmax_node_outputs;
@@ -1236,9 +1236,6 @@ bool AnfAlgo::IsKernelDynamicImpl(const AnfNodePtr &node) {
   }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
-  if (!HasNodeAttr(kAttrIsKernelDynamicImpl, cnode)) {
-    return false;
-  }
   return GetBooleanAttr(node, kAttrIsKernelDynamicImpl);
 }

@@ -16,6 +16,7 @@
 #include "common/backend_common_test.h"
 #include "common/py_func_graph_fetcher.h"
 #include "backend/common/optimizer/optimizer.h"
+#include "plugin/device/ascend/optimizer/mindir/ascend_convert_const_input_to_attr.h"
 #include "plugin/device/ascend/optimizer/ir_fusion/confusion_softmax_grad_rule.h"
 #include "include/common/debug/anf_ir_dump.h"

@@ -43,6 +44,7 @@ TEST_F(TestHWOptimizeConfusionSoftmaxGradRule, test_confusion_softmax_grad_rule)

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
+  pm->AddPass(std::make_shared<opt::AscendConvertConstInputToAttr>());
   pm->AddPass(std::make_shared<opt::ConfusionSoftmaxGradRule>());
   optimizer->AddPassManager(pm);
   FuncGraphPtr new_graph = optimizer->Optimize(fg);