diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc
index 4eb055829f..f15069ea25 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc
@@ -18,6 +18,8 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <unordered_map>
+#include <utility>
 
 #include "backend/session/anf_runtime_algorithm.h"
 #include "backend/optimizer/common/helper.h"
@@ -73,26 +75,7 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
   AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node);
   auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node);
   MS_EXCEPTION_IF_NULL(used_cast_node_list);
-  std::unordered_map<std::string, size_t> format_counter;
-  for (const auto &node_info : *used_cast_node_list) {
-    MS_EXCEPTION_IF_NULL(node_info.first);
-    auto cast_out_node = node_info.first->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cast_out_node);
-    size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
-    for (size_t index = 0; index < input_num; ++index) {
-      if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
-          cast_node) {
-        continue;
-      }
-      auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
-      auto it = format_counter.find(format);
-      if (it == format_counter.end()) {
-        format_counter[format] = 1;
-      } else {
-        it->second++;
-      }
-    }
-  }
+  std::unordered_map<std::string, size_t> format_counter = CalculateFormat(used_cast_node_list, cast_node);
   auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0);
   string convert_format = kOpFormat_DEFAULT;
   if (cast_input_format == kOpFormat_DEFAULT) {
@@ -121,5 +104,33 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
     SetCastFormat(cast_node, convert_format);
   }
 }
+
+std::unordered_map<std::string, size_t> ConvertCastFormat::CalculateFormat(
+  const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
+  const CNodePtr &cast_node) const {
+  MS_EXCEPTION_IF_NULL(used_cast_node_list);
+  MS_EXCEPTION_IF_NULL(cast_node);
+  std::unordered_map<std::string, size_t> format_counter;
+  for (const auto &node_info : *used_cast_node_list) {
+    MS_EXCEPTION_IF_NULL(node_info.first);
+    auto cast_out_node = node_info.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cast_out_node);
+    size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
+    for (size_t index = 0; index < input_num; ++index) {
+      if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
+          cast_node) {
+        continue;
+      }
+      auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
+      auto it = format_counter.find(format);
+      if (it == format_counter.end()) {
+        format_counter[format] = 1;
+      } else {
+        it->second++;
+      }
+    }
+  }
+  return format_counter;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h
index 5c9d2f22ce..31b046037d 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h
@@ -17,6 +17,10 @@
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
 
 #include <string>
+#include <memory>
+#include <utility>
+#include <vector>
+#include <unordered_map>
 
 #include "backend/optimizer/common/optimizer.h"
 
@@ -30,6 +34,9 @@ class ConvertCastFormat : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  std::unordered_map<std::string, size_t> CalculateFormat(
+    const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
+    const CNodePtr &cast_node) const;
   void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const;
   void SetCastFormat(const CNodePtr &cast_node, const string &format) const;
 };
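Note: CalculateFormat is a straight extraction of the old inline loop — it tallies how often each data format appears among the inputs that consume the Cast's output, so ChangeCastFormat can pick the dominant format. Below is a minimal standalone sketch of that tallying pattern; the plain string list stands in for the ANF-graph traversal, and nothing here is a MindSpore API.

// Standalone sketch of the format-tallying pattern in CalculateFormat.
// The input vector is a hypothetical stand-in for the graph walk.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Formats observed on the inputs that consume the Cast's output.
  std::vector<std::string> user_input_formats = {"NC1HWC0", "DefaultFormat", "NC1HWC0"};

  std::unordered_map<std::string, size_t> format_counter;
  for (const auto &format : user_input_formats) {
    // operator[] value-initializes a missing key to 0, so this one-liner is
    // equivalent to the find/insert-or-increment branching kept by the refactor.
    ++format_counter[format];
  }

  for (const auto &entry : format_counter) {
    std::cout << entry.first << ": " << entry.second << "\n";
  }
  return 0;
}

The extraction keeps the original find/else bookkeeping verbatim; as the comment notes, ++format_counter[format] would be the equivalent idiom.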
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc
index 70bea4a025..ef44c47d51 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc
@@ -143,18 +143,22 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
   }
   // insert depend
   if (origin_format != cur_format || origin_type != cur_type) {
-    std::vector<AnfNodePtr> depend_nodes;
-    if (get_item.get() != nullptr) {
-      depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
-    } else {
-      depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
-    }
-    final_node = func_graph->NewCNode(depend_nodes);
+    final_node = MakeDependency(get_item, final_node, cnode, func_graph);
     MS_LOG(INFO) << "DealRefTransAndCast add depend, op debug info is " << final_node->DebugString();
   }
-
   return final_node;
 }
+
+CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
+                                             const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
+  std::vector<AnfNodePtr> depend_nodes;
+  if (get_item != nullptr) {
+    depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
+  } else {
+    depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
+  }
+  return func_graph->NewCNode(depend_nodes);
+}
 CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                        const std::shared_ptr<kernel::OpInfo> &op_info) const {
   MS_EXCEPTION_IF_NULL(op_info);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h
index fe9bd3e93a..48c1674ade 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h
@@ -33,6 +33,8 @@ class DealRefTransAndCast : public TransDataSplit {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  CNodePtr MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, const CNodePtr &cnode,
+                          const FuncGraphPtr &func_graph) const;
   CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
   void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
   CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
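Note: MakeDependency isolates one decision — the Depend node's second input is get_item when the ref output went through a getitem node, otherwise the original cnode — and the hoisted version also drops the redundant get_item.get() != nullptr spelling, since smart pointers compare against nullptr directly. A standalone sketch of that fallback; ToyNode and the node names are illustrative, not ANF types:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Mirrors MakeDependency's shape: pick get_item if present, else fall back
// to cnode, and build the {prim, chosen, final} input list.
std::vector<ToyNodePtr> BuildDependInputs(const ToyNodePtr &get_item, const ToyNodePtr &final_node,
                                          const ToyNodePtr &cnode) {
  const ToyNodePtr &real_input = (get_item != nullptr) ? get_item : cnode;
  ToyNodePtr depend_prim = std::make_shared<ToyNode>(ToyNode{"Depend"});
  return {depend_prim, real_input, final_node};
}

int main() {
  auto cnode = std::make_shared<ToyNode>(ToyNode{"Conv2D"});
  auto final_node = std::make_shared<ToyNode>(ToyNode{"TransData"});
  for (const auto &n : BuildDependInputs(nullptr, final_node, cnode)) {
    std::cout << n->name << " ";
  }
  std::cout << "\n";  // prints: Depend Conv2D TransData
  return 0;
}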
diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h
index d5ea814f29..b6976c8096 100644
--- a/mindspore/ccsrc/common/trans.h
+++ b/mindspore/ccsrc/common/trans.h
@@ -97,7 +97,6 @@ const std::map kTransFormatMapOfHostToDevice{
   {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
   {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
   {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0},     {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
-
 }  // namespace trans
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
index 36162e65e7..5a5d67f351 100644
--- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc
@@ -182,32 +182,36 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
   }
 }
 
+BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const {
+  SyncData(py_args);
+  auto size = py_args.size();
+  py::tuple input_args(size - 2);
+  for (size_t i = 0; i < size - 2; ++i) {
+    input_args[i] = py_args[i];
+  }
+  py::tuple convert_args(py_args.size());
+  ConvertCTensorToPyTensor(py_args, &convert_args);
+  auto inst = pynative::PynativeExecutor::GetInstance();
+  MS_EXCEPTION_IF_NULL(inst);
+  try {
+    MS_LOG(DEBUG) << "Run bprop function start";
+    inst->NewGraph(hook_, input_args.cast<py::args>());
+    py::object grads_obj = hook_(*convert_args);
+    py::tuple grads = check_bprop_out(grads_obj, py_args);
+    inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
+    MS_LOG(DEBUG) << "Run bprop function end";
+    return std::make_shared<PyObjectRef>(grads);
+  } catch (std::exception &bt) {
+    inst->ClearRes();
+    std::rethrow_exception(std::current_exception());
+  }
+}
+
 BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
   py::tuple py_args = ConvertDatatoPyTuple(args);
   bool is_bprop = this->HasAttr(kBpropAttrName);
   if (is_bprop) {
-    SyncData(py_args);
-    auto size = py_args.size();
-    py::tuple input_args(size - 2);
-    for (size_t i = 0; i < size - 2; ++i) {
-      input_args[i] = py_args[i];
-    }
-    py::tuple convert_args(py_args.size());
-    ConvertCTensorToPyTensor(py_args, &convert_args);
-    auto inst = pynative::PynativeExecutor::GetInstance();
-    MS_EXCEPTION_IF_NULL(inst);
-    try {
-      MS_LOG(DEBUG) << "Run bprop function start";
-      inst->NewGraph(hook_, input_args.cast<py::args>());
-      py::object grads_obj = hook_(*convert_args);
-      py::tuple grads = check_bprop_out(grads_obj, py_args);
-      inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
-      MS_LOG(DEBUG) << "Run bprop function end";
-      return std::make_shared<PyObjectRef>(grads);
-    } catch (std::exception &bt) {
-      inst->ClearRes();
-      std::rethrow_exception(std::current_exception());
-    }
+    return RunBpropHookFunction(py_args);
   }
   SyncData(py_args[2]);
   bool is_cell = this->HasAttr(kCellHookAttrName);
diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h
index 6cdf46fe2a..21455cc3ce 100644
--- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h
+++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h
@@ -55,6 +55,7 @@ class PrimitivePy : public Primitive {
   void set_hook(const py::function &hook) { hook_ = hook; }
   py::function hook() const { return hook_; }
   BaseRef RunHookFunction(const VectorRef &args) const override;
+  BaseRef RunBpropHookFunction(const py::tuple &py_args) const;
   BaseRef RunComputeFunction(const VectorRef &args) const override;
   py::object RunPyComputeFunction(const py::tuple &py_args) const;
   bool HasComputeFunction() const;
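Note: RunBpropHookFunction carries the try/catch contract out of RunHookFunction unchanged — if the user's bprop hook throws, the PyNative executor state is cleared and the in-flight exception is rethrown as-is, preserving its dynamic type. A standalone sketch of that cleanup-and-rethrow pattern, with a stand-in executor in place of pynative::PynativeExecutor:

#include <exception>
#include <iostream>
#include <stdexcept>

struct ToyExecutor {
  void ClearRes() { std::cout << "executor state cleared\n"; }
};

// Run a callable; on failure, clear executor state and rethrow the original
// exception object (std::current_exception preserves its concrete type).
template <typename Fn>
void RunWithCleanup(ToyExecutor *executor, Fn &&fn) {
  try {
    fn();
  } catch (const std::exception &) {
    executor->ClearRes();
    std::rethrow_exception(std::current_exception());
  }
}

int main() {
  ToyExecutor executor;
  try {
    RunWithCleanup(&executor, [] { throw std::runtime_error("bprop hook failed"); });
  } catch (const std::runtime_error &e) {
    std::cout << "caught again: " << e.what() << "\n";
  }
  return 0;
}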
diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc
index f7cf304906..50e4ce1108 100644
--- a/mindspore/core/ops/conv2d.cc
+++ b/mindspore/core/ops/conv2d.cc
@@ -28,29 +28,8 @@
 namespace mindspore {
 namespace ops {
 namespace {
-abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
-  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
-  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
-  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
-  if (format == NHWC) {
-    x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
-    w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
-  }
-
-  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
-  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
-  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
-                              "w_shape[1]", w_shape[1], prim_name);
-  auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
-  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
-  std::vector<int64_t> temp_w;
-  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
-  CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
-                              "w_shape[2:4]", temp_w, prim_name);
-
+std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape,
+                                const std::vector<int64_t> &x_shape, const int64_t &out_channel) {
   auto kernel_size_h = w_shape[2];
   auto kernel_size_w = w_shape[3];
   auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
@@ -92,13 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
     h_out = floor(h_out);
     w_out = floor(w_out);
   }
-  CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
-  primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
+  CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
+  primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
   std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
+  return out_shape;
+}
+
+abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
+  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
+  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
+  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
   if (format == NHWC) {
-    out_shape = {x_shape[0], h_out, w_out, out_channel};
+    x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
+    w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
+  }
+  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
+  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
+  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
+                              "w_shape[1]", w_shape[1], prim_name);
+  auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
+  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
+  std::vector<int64_t> temp_w;
+  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
+  CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
+                              "w_shape[2:4]", temp_w, prim_name);
+  auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
+  if (format == NHWC) {
+    out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]};
   }
-
   return std::make_shared<abstract::Shape>(out_shape);
 }
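Note: the shape handling around SetPadList rests on two index permutations. An NHWC input shape is brought into NCHW order with {0, 3, 1, 2}; converting the NCHW result back to NHWC needs the inverse order {0, 2, 3, 1} ({N, C, H, W} -> {N, H, W, C}), which is what the out_shape line above uses. A standalone sketch verifying the round trip:

#include <cstdint>
#include <iostream>
#include <vector>

// Reorder a 4-D shape by picking the source index for each output slot.
std::vector<int64_t> Permute(const std::vector<int64_t> &shape, const std::vector<size_t> &index) {
  return {shape[index[0]], shape[index[1]], shape[index[2]], shape[index[3]]};
}

int main() {
  const std::vector<size_t> kNhwcToNchw = {0, 3, 1, 2};  // applied to x_shape/w_shape above
  const std::vector<size_t> kNchwToNhwc = {0, 2, 3, 1};  // inverse, applied to out_shape

  std::vector<int64_t> nhwc = {8, 224, 224, 3};            // {N, H, W, C}
  std::vector<int64_t> nchw = Permute(nhwc, kNhwcToNchw);  // {8, 3, 224, 224}
  std::vector<int64_t> back = Permute(nchw, kNchwToNhwc);  // {8, 224, 224, 3}

  std::cout << "round trip ok: " << std::boolalpha << (back == nhwc) << "\n";
  return 0;
}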
diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc
index cfddf16230..ac68700a15 100644
--- a/mindspore/core/ops/grad/conv2d_backprop_input.cc
+++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc
@@ -23,32 +23,8 @@
 namespace mindspore {
 namespace ops {
-AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto prim_name = primitive->name();
-  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
-  for (const auto &item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
-  auto doutput = input_args[0];
-  auto x_size = input_args[2];
-  auto x_size_value = x_size->GetValueTrack();
-  MS_EXCEPTION_IF_NULL(x_size);
-  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
-  // infer dtype
-  auto dtype = doutput->BuildType();
-  if (!dtype->isa<TensorType>()) {
-    MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got " << dtype->ToString();
-  }
-  auto input_tensor_type = dtype->cast<TensorTypePtr>();
-  MS_EXCEPTION_IF_NULL(input_tensor_type);
-  auto element = input_tensor_type->element();
-  // infer shape
-  auto dout_shape = doutput->BuildShape();
-  MS_EXCEPTION_IF_NULL(doutput);
-  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
-  auto dout_shape_norm = dout_shapeptr->shape();
+void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
+                const std::vector<int64_t> &x_size_v) {
   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
   auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
   auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
@@ -76,6 +52,34 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
     pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
   }
   primitive->AddAttr(kPadList, MakeValue(pad_list));
+}
+
+AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto doutput = input_args[0];
+  auto x_size = input_args[2];
+  auto x_size_value = x_size->GetValueTrack();
+  MS_EXCEPTION_IF_NULL(x_size);
+  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
+  // infer dtype
+  auto dtype = doutput->BuildType();
+  if (!dtype->isa<TensorType>()) {
+    MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got " << dtype->ToString();
+  }
+  auto input_tensor_type = dtype->cast<TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(input_tensor_type);
+  auto element = input_tensor_type->element();
+  // infer shape
+  auto dout_shape = doutput->BuildShape();
+  MS_EXCEPTION_IF_NULL(doutput);
+  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
+  auto dout_shape_norm = dout_shapeptr->shape();
+  SetPadList(primitive, dout_shape_norm, x_size_v);
   return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
 }
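Note: SetPadList recomputes the pad list from kernel size, stride, and dilation before attaching it as kPadList. The hunk elides the pad-mode branches, but the conventional SAME-padding rule — assumed here, not copied from the elided context — pads just enough that the forward output is ceil(in / stride), with the smaller half in front. A standalone sketch for one spatial dimension:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Conventional SAME padding for one spatial dimension: total padding needed so
// that out = ceil(in / stride), split with the smaller half on the top/left.
std::vector<int64_t> SamePad1D(int64_t in, int64_t out, int64_t kernel, int64_t stride, int64_t dilation) {
  int64_t pad_needed = std::max<int64_t>(0, (out - 1) * stride + (kernel - 1) * dilation + 1 - in);
  int64_t pad_before = pad_needed / 2;
  return {pad_before, pad_needed - pad_before};
}

int main() {
  // A dout of 112 restored to a forward input of 224, 3x3 kernel, stride 2.
  auto pad_h = SamePad1D(/*in=*/224, /*out=*/112, /*kernel=*/3, /*stride=*/2, /*dilation=*/1);
  std::cout << "pad_top=" << pad_h[0] << " pad_bottom=" << pad_h[1] << "\n";  // 0 and 1
  return 0;
}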
diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc
index 22e0b464da..f32b7466ed 100644
--- a/mindspore/core/ops/max_pool.cc
+++ b/mindspore/core/ops/max_pool.cc
@@ -102,8 +102,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
   auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
   auto pad_mode_value = (primitive->GetAttr(kPadMode));
-  PadMode pad_mode = PAD;
-  pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
+  PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
   auto batch = in_shape[0];
   auto channel = in_shape[1];
   auto in_h = in_shape[2];
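Note: the max_pool change only collapses a default-then-assign pair into a single initialization of the PadMode enum from the stored attribute. For context, a standalone sketch of the same enum construction plus the standard pooling output-size rule that the surrounding InferShape code goes on to compute; the enum numbering and the formula here are conventional assumptions, not lifted from the MindSpore source:

#include <cmath>
#include <cstdint>
#include <iostream>

enum class PadMode : int64_t { PAD = 0, SAME = 1, VALID = 2 };  // illustrative numbering

// Standard pooling output size for one spatial dimension.
int64_t PoolOut1D(int64_t in, int64_t kernel, int64_t stride, PadMode mode) {
  if (mode == PadMode::SAME) {
    return static_cast<int64_t>(std::ceil(static_cast<double>(in) / stride));
  }
  return static_cast<int64_t>(std::ceil(static_cast<double>(in - kernel + 1) / stride));  // VALID
}

int main() {
  // Enum constructed directly from the stored attribute value, as in the hunk above.
  int64_t pad_mode_attr = static_cast<int64_t>(PadMode::VALID);
  PadMode pad_mode = PadMode(pad_mode_attr);
  std::cout << "h_out=" << PoolOut1D(224, 3, 2, pad_mode) << "\n";  // 111
  return 0;
}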