From: @lianliguang
Reviewed-by: @zhoufeng54, @chujinjin
Signed-off-by: @chujinjin
mindspore-ci-bot 2021-03-31 16:34:27 +08:00 committed by Gitee
commit 36dbb2690e
10 changed files with 139 additions and 106 deletions

View File

@@ -18,6 +18,8 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
@@ -73,26 +75,7 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node);
auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node);
MS_EXCEPTION_IF_NULL(used_cast_node_list);
std::unordered_map<string, size_t> format_counter = CalculateFormat(used_cast_node_list, cast_node);
auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0);
string convert_format = kOpFormat_DEFAULT;
if (cast_input_format == kOpFormat_DEFAULT) {
@@ -121,5 +104,33 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr
SetCastFormat(cast_node, convert_format);
}
}
std::unordered_map<string, size_t> ConvertCastFormat::CalculateFormat(
const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
const CNodePtr &cast_node) const {
MS_EXCEPTION_IF_NULL(used_cast_node_list);
MS_EXCEPTION_IF_NULL(cast_node);
std::unordered_map<string, size_t> format_counter;
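// Tally, for every node that consumes this cast's output, the format each consuming input expects.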
for (const auto &node_info : *used_cast_node_list) {
MS_EXCEPTION_IF_NULL(node_info.first);
auto cast_out_node = node_info.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_out_node);
size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
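// Only count the inputs whose real source (resolved through nop nodes by VisitKernelWithReturnType) is this cast.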
for (size_t index = 0; index < input_num; ++index) {
if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node, index), 0).first != cast_node) {
continue;
}
auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
auto it = format_counter.find(format);
if (it == format_counter.end()) {
format_counter[format] = 1;
} else {
it->second++;
}
}
}
return format_counter;
}
} // namespace opt
} // namespace mindspore

View File

@@ -17,6 +17,10 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
#include <string>
#include <unordered_map>
#include <utility>
#include <memory>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
@@ -30,6 +34,9 @@ class ConvertCastFormat : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::unordered_map<string, size_t> CalculateFormat(
const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
const CNodePtr &cast_node) const;
void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const;
void SetCastFormat(const CNodePtr &cast_node, const string &format) const;
};

View File

@@ -143,18 +143,22 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
}
// insert depend
if (origin_format != cur_format || origin_type != cur_type) {
final_node = MakeDependency(get_item, final_node, cnode, func_graph);
MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString();
}
return final_node;
}
CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
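// Depend(x, y) evaluates to x while forcing y to stay in the execution order: return the original ref
// output (or its tuple_getitem) as the visible node, but keep the inserted trans/cast chain (final_node) alive.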
std::vector<AnfNodePtr> depend_nodes;
if (get_item != nullptr) {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
} else {
depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
}
return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::shared_ptr<kernel::OpInfo> &op_info) const {
MS_EXCEPTION_IF_NULL(op_info);

View File

@@ -33,6 +33,8 @@ class DealRefTransAndCast : public TransDataSplit {
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
CNodePtr MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node, const CNodePtr &cnode,
const FuncGraphPtr &func_graph) const;
CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,

View File

@@ -97,7 +97,6 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
} // namespace trans
} // namespace mindspore

View File

@@ -182,32 +182,36 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
}
}
BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const {
SyncData(py_args);
auto size = py_args.size();
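// For a bprop hook, py_args is laid out as (forward inputs..., out, dout); only the
// forward inputs are replayed into the new pynative graph.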
py::tuple input_args(size - 2);
for (size_t i = 0; i < size - 2; ++i) {
input_args[i] = py_args[i];
}
py::tuple convert_args(py_args.size());
ConvertCTensorToPyTensor(py_args, &convert_args);
auto inst = pynative::PynativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(inst);
try {
MS_LOG(DEBUG) << "Run bprop function start";
inst->NewGraph(hook_, input_args.cast<py::args>());
py::object grads_obj = hook_(*convert_args);
py::tuple grads = check_bprop_out(grads_obj, py_args);
inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
MS_LOG(DEBUG) << "Run bprop function end";
return std::make_shared<PyObjectRef>(grads);
} catch (std::exception &bt) {
inst->ClearRes();
std::rethrow_exception(std::current_exception());
}
}
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
py::tuple py_args = ConvertDatatoPyTuple(args);
bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) {
return RunBpropHookFunction(py_args);
}
SyncData(py_args[2]);
bool is_cell = this->HasAttr(kCellHookAttrName);

View File

@@ -55,6 +55,7 @@ class PrimitivePy : public Primitive {
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
BaseRef RunHookFunction(const VectorRef &args) const override;
BaseRef RunBpropHookFunction(const py::tuple &py_args) const;
BaseRef RunComputeFunction(const VectorRef &args) const override;
py::object RunPyComputeFunction(const py::tuple &py_args) const;
bool HasComputeFunction() const;

View File

@@ -28,29 +28,8 @@
namespace mindspore {
namespace ops {
namespace {
std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape,
const std::vector<int64_t> &x_shape, const int64_t &out_channel) {
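// Compute the output H/W under the primitive's pad mode, store the resolved pad_list as the
// kPadList attribute, and return the NCHW-ordered output shape.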
auto kernel_size_h = w_shape[2];
auto kernel_size_w = w_shape[3];
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
@@ -92,13 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
h_out = floor(h_out);
w_out = floor(w_out);
}
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name)));
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
return out_shape;
}
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
if (format == NHWC) {
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
}
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
"w_shape[1]", w_shape[1], prim_name);
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
std::vector<int64_t> temp_w;
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
"w_shape[2:4]", temp_w, prim_name);
auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
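// SetPadList returns an NCHW-ordered shape; permute it back to NHWC when that is the requested layout.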
if (format == NHWC) {
out_shape = {out_shape[0], out_shape[2], out_shape[3], out_shape[1]};
}
return std::make_shared<abstract::Shape>(out_shape);
}

View File

@@ -23,32 +23,8 @@
namespace mindspore {
namespace ops {
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
const std::vector<int64_t> &x_size_v) {
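// Resolve pad_list from the pad mode (all zeros for "valid", computed from the dout and x sizes
// for "same", the user-supplied kPad attribute otherwise) and record it on the primitive as kPadList.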
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
@@ -76,6 +52,34 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, co
pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
}
primitive->AddAttr(kPadList, MakeValue(pad_list));
}
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto doutput = input_args[0];
auto x_size = input_args[2];
MS_EXCEPTION_IF_NULL(x_size);
auto x_size_value = x_size->GetValueTrack();
MS_EXCEPTION_IF_NULL(x_size_value);
auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
// infer dtype
auto dtype = doutput->BuildType();
if (!dtype->isa<TensorType>()) {
MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got" << dtype->ToString();
}
auto input_tensor_type = dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(input_tensor_type);
auto element = input_tensor_type->element();
// infer shape
MS_EXCEPTION_IF_NULL(doutput);
auto dout_shape = doutput->BuildShape();
auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(dout_shapeptr);
auto dout_shape_norm = dout_shapeptr->shape();
SetPadList(primitive, dout_shape_norm, x_size_v);
return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
}

View File

@@ -102,8 +102,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
auto pad_mode_value = primitive->GetAttr(kPadMode);
PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
auto batch = in_shape[0];
auto channel = in_shape[1];
auto in_h = in_shape[2];