!14448 fix codex
From: @lianliguang
Reviewed-by: @zhoufeng54, @chujinjin
Signed-off-by: @lilongfei15
commit 42f6d71560
@@ -18,6 +18,8 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <utility>

#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
@@ -73,26 +75,7 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const {
  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node);
  auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node);
  MS_EXCEPTION_IF_NULL(used_cast_node_list);
  std::unordered_map<string, size_t> format_counter;
  for (const auto &node_info : *used_cast_node_list) {
    MS_EXCEPTION_IF_NULL(node_info.first);
    auto cast_out_node = node_info.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cast_out_node);
    size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
    for (size_t index = 0; index < input_num; ++index) {
      if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
          cast_node) {
        continue;
      }
      auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
      auto it = format_counter.find(format);
      if (it == format_counter.end()) {
        format_counter[format] = 1;
      } else {
        it->second++;
      }
    }
  }
  std::unordered_map<string, size_t> format_counter = CalculateFormat(used_cast_node_list, cast_node);
  auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0);
  string convert_format = kOpFormat_DEFAULT;
  if (cast_input_format == kOpFormat_DEFAULT) {
@@ -121,5 +104,33 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const {
    SetCastFormat(cast_node, convert_format);
  }
}

std::unordered_map<string, size_t> ConvertCastFormat::CalculateFormat(
  const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
  const CNodePtr &cast_node) const {
  MS_EXCEPTION_IF_NULL(used_cast_node_list);
  MS_EXCEPTION_IF_NULL(cast_node);
  std::unordered_map<string, size_t> format_counter;
  for (const auto &node_info : *used_cast_node_list) {
    MS_EXCEPTION_IF_NULL(node_info.first);
    auto cast_out_node = node_info.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cast_out_node);
    size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node);
    for (size_t index = 0; index < input_num; ++index) {
      if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
          cast_node) {
        continue;
      }
      auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
      auto it = format_counter.find(format);
      if (it == format_counter.end()) {
        format_counter[format] = 1;
      } else {
        it->second++;
      }
    }
  }
  return format_counter;
}
}  // namespace opt
}  // namespace mindspore
@@ -17,6 +17,10 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_

#include <string>
#include <unordered_map>
#include <memory>
#include <vector>
#include <utility>

#include "backend/optimizer/common/optimizer.h"
@@ -30,6 +34,9 @@ class ConvertCastFormat : public PatternProcessPass {
  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  std::unordered_map<string, size_t> CalculateFormat(
    const std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> &used_cast_node_list,
    const CNodePtr &cast_node) const;
  void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const;
  void SetCastFormat(const CNodePtr &cast_node, const string &format) const;
};
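The extracted CalculateFormat is a plain occurrence tally over the formats consumed by the Cast node's users. A minimal standalone sketch of that tally pattern (toy CountFormats helper with plain strings, not the MindSpore pass or its AnfNode types):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Tally how often each format string occurs, as CalculateFormat does for the
// input formats of a Cast node's users. operator[] value-initializes missing
// keys to 0, so the find/insert branch in the patch can also be written as a
// single increment.
std::unordered_map<std::string, size_t> CountFormats(const std::vector<std::string> &formats) {
  std::unordered_map<std::string, size_t> counter;
  for (const auto &format : formats) {
    ++counter[format];
  }
  return counter;
}

int main() {
  auto counter = CountFormats({"NC1HWC0", "DefaultFormat", "NC1HWC0"});
  for (const auto &kv : counter) {
    std::cout << kv.first << ": " << kv.second << "\n";
  }
  return 0;
}
```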
@@ -143,18 +143,23 @@ CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_
  }
  // insert depend
  if (origin_format != cur_format || origin_type != cur_type) {
    std::vector<AnfNodePtr> depend_nodes;
    if (get_item.get() != nullptr) {
      depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
    } else {
      depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
    }
    final_node = func_graph->NewCNode(depend_nodes);
    final_node = MakeDependency(get_item, final_node, cnode, func_graph);
    MS_LOG(INFO) << "DealRefTransAndCast add depend, op debug info is " << final_node->DebugString();
  }
  return final_node;
}

CNodePtr DealRefTransAndCast::MakeDependency(const CNodePtr &get_item, const CNodePtr &final_node,
                                             const CNodePtr &cnode, const FuncGraphPtr &func_graph) const {
  std::vector<AnfNodePtr> depend_nodes;
  if (get_item != nullptr) {
    depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), get_item, final_node};
  } else {
    depend_nodes = std::vector<AnfNodePtr>{NewValueNode(prim::kPrimDepend), cnode, final_node};
  }
  return func_graph->NewCNode(depend_nodes);
}
CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                       const std::shared_ptr<kernel::OpInfo> &op_info) const {
  MS_EXCEPTION_IF_NULL(op_info);
@@ -33,6 +33,8 @@ class DealRefTransAndCast : public TransDataSplit {
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  CNodePtr MakeDependency(const CNodePtr &getitem, const CNodePtr &final_node, const CNodePtr &cnode,
                          const FuncGraphPtr &func_graph) const;
  CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
  void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const;
  CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
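MakeDependency builds a Depend node whose first input is the ordering anchor (the tuple-getitem when one exists, otherwise the original cnode) and whose second input is the node that must not be eliminated. A sketch of that anchor-selection pattern with toy graph types (Node/NodePtr here are stand-ins, not MindSpore's AnfNode/CNode API):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for MindSpore's graph node types; pattern sketch only.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Mirror of MakeDependency's logic: a Depend node keeps final_node alive while
// anchoring execution order on get_item when it exists, otherwise on cnode.
NodePtr MakeDependency(const NodePtr &get_item, const NodePtr &final_node, const NodePtr &cnode) {
  NodePtr anchor = (get_item != nullptr) ? get_item : cnode;
  return std::make_shared<Node>(Node{"Depend", {anchor, final_node}});
}

int main() {
  auto cnode = std::make_shared<Node>(Node{"Conv2D", {}});
  auto trans = std::make_shared<Node>(Node{"TransData", {cnode}});
  auto depend = MakeDependency(nullptr, trans, cnode);
  std::cout << depend->op << " anchored on " << depend->inputs[0]->op << "\n";
  return 0;
}
```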
@@ -97,7 +97,6 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
  {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
  {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
  {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0},     {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};

}  // namespace trans
}  // namespace mindspore
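kTransFormatMapOfHostToDevice is a lookup-table dispatch: each device format name maps to the converter that produces it from host layout. A minimal sketch of the same design, assuming an illustrative converter signature (the real FormatTransfer type and NchwTo* converters are not reproduced here):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative converter signature; the real table registers functions such as
// NchwToNc1hwc0 with MindSpore's own argument types.
using FormatTransfer = std::function<std::vector<float>(const std::vector<float> &)>;

std::vector<float> Identity(const std::vector<float> &in) { return in; }

const std::map<std::string, FormatTransfer> kHostToDevice{
  {"NC1HWC0", Identity},  // placeholder converter for the sketch
};

int main() {
  auto it = kHostToDevice.find("NC1HWC0");
  if (it != kHostToDevice.end()) {
    auto out = it->second({1.0f, 2.0f});
    std::cout << "converted " << out.size() << " elements\n";
  }
  return 0;
}
```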
@@ -182,32 +182,36 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj
  }
}

BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const {
  SyncData(py_args);
  auto size = py_args.size();
  py::tuple input_args(size - 2);
  for (size_t i = 0; i < size - 2; ++i) {
    input_args[i] = py_args[i];
  }
  py::tuple convert_args(py_args.size());
  ConvertCTensorToPyTensor(py_args, &convert_args);
  auto inst = pynative::PynativeExecutor::GetInstance();
  MS_EXCEPTION_IF_NULL(inst);
  try {
    MS_LOG(DEBUG) << "Run bprop function start";
    inst->NewGraph(hook_, input_args.cast<py::args>());
    py::object grads_obj = hook_(*convert_args);
    py::tuple grads = check_bprop_out(grads_obj, py_args);
    inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
    MS_LOG(DEBUG) << "Run bprop function end";
    return std::make_shared<PyObjectRef>(grads);
  } catch (std::exception &bt) {
    inst->ClearRes();
    std::rethrow_exception(std::current_exception());
  }
}

BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
  py::tuple py_args = ConvertDatatoPyTuple(args);
  bool is_bprop = this->HasAttr(kBpropAttrName);
  if (is_bprop) {
    SyncData(py_args);
    auto size = py_args.size();
    py::tuple input_args(size - 2);
    for (size_t i = 0; i < size - 2; ++i) {
      input_args[i] = py_args[i];
    }
    py::tuple convert_args(py_args.size());
    ConvertCTensorToPyTensor(py_args, &convert_args);
    auto inst = pynative::PynativeExecutor::GetInstance();
    MS_EXCEPTION_IF_NULL(inst);
    try {
      MS_LOG(DEBUG) << "Run bprop function start";
      inst->NewGraph(hook_, input_args.cast<py::args>());
      py::object grads_obj = hook_(*convert_args);
      py::tuple grads = check_bprop_out(grads_obj, py_args);
      inst->EndGraph(hook_, grads_obj, input_args.cast<py::args>());
      MS_LOG(DEBUG) << "Run bprop function end";
      return std::make_shared<PyObjectRef>(grads);
    } catch (std::exception &bt) {
      inst->ClearRes();
      std::rethrow_exception(std::current_exception());
    }
    return RunBpropHookFunction(py_args);
  }
  SyncData(py_args[2]);
  bool is_cell = this->HasAttr(kCellHookAttrName);
@@ -55,6 +55,7 @@ class PrimitivePy : public Primitive {
  void set_hook(const py::function &hook) { hook_ = hook; }
  py::function hook() const { return hook_; }
  BaseRef RunHookFunction(const VectorRef &args) const override;
  BaseRef RunBpropHookFunction(const py::tuple &py_args) const;
  BaseRef RunComputeFunction(const VectorRef &args) const override;
  py::object RunPyComputeFunction(const py::tuple &py_args) const;
  bool HasComputeFunction() const;
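The extracted RunBpropHookFunction keeps the cleanup-then-rethrow error handling intact: on any exception the executor state is cleared, then the same exception is propagated to the caller. A self-contained sketch of that pattern (ClearRes here is a stand-in for the real executor call):

```cpp
#include <exception>
#include <iostream>
#include <stdexcept>

// Stand-in for pynative::PynativeExecutor::ClearRes in the real code.
void ClearRes() { std::cout << "state cleared\n"; }

// Cleanup-then-rethrow: std::rethrow_exception(std::current_exception())
// re-raises the active exception object unchanged, so callers see the
// original error; the caught reference itself is never used.
void RunWithCleanup() {
  try {
    throw std::runtime_error("bprop hook failed");
  } catch (std::exception &) {
    ClearRes();
    std::rethrow_exception(std::current_exception());
  }
}

int main() {
  try {
    RunWithCleanup();
  } catch (const std::exception &e) {
    std::cout << "caller sees: " << e.what() << "\n";
  }
  return 0;
}
```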
@@ -28,29 +28,8 @@
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
  if (format == NHWC) {
    x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
    w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
  }

  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
                              "w_shape[1]", w_shape[1], prim_name);
  auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
  std::vector<int64_t> temp_w;
  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
  CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
                              "w_shape[2:4]", temp_w, prim_name);

std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &w_shape,
                                const std::vector<int64_t> &x_shape, const int64_t &out_channel) {
  auto kernel_size_h = w_shape[2];
  auto kernel_size_w = w_shape[3];
  auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
@@ -92,14 +71,36 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
    h_out = floor(h_out);
    w_out = floor(w_out);
  }
  CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name);
  primitive->AddAttr(kPadList,
                     MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true)));
  CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
  primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
  std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
  return out_shape;
}
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
  if (format == NHWC) {
    out_shape = {x_shape[0], h_out, w_out, out_channel};
    x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
    w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
  }
  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
                              "w_shape[1]", w_shape[1], prim_name);
  auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
  std::vector<int64_t> temp_w;
  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
  CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
                              "w_shape[2:4]", temp_w, prim_name);
  auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
  if (format == NHWC) {
    out_shape = {out_shape[0], out_shape[3], out_shape[1], out_shape[2]};
  }

  return std::make_shared<abstract::Shape>(out_shape);
}
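The SetPadList helper extracted here encodes the standard convolution output-size and padding arithmetic. A standalone sketch of the "same" pad-mode case (plain C++ with standard formulas; the real pass reads kernel, stride, and dilation from primitive attributes rather than parameters):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// SAME pad-mode arithmetic: out = ceil(in / stride), then pad just enough so
// the (dilated) kernel covers the edges. Returns
// {pad_top, pad_bottom, pad_left, pad_right}, as stored in kPadList.
std::vector<int64_t> SamePadList(int64_t in_h, int64_t in_w, int64_t kernel_h, int64_t kernel_w,
                                 int64_t stride_h, int64_t stride_w, int64_t dilation_h, int64_t dilation_w) {
  int64_t out_h = (in_h + stride_h - 1) / stride_h;  // ceil division
  int64_t out_w = (in_w + stride_w - 1) / stride_w;
  int64_t pad_h = std::max<int64_t>(0, (out_h - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - in_h);
  int64_t pad_w = std::max<int64_t>(0, (out_w - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - in_w);
  return {pad_h / 2, pad_h - pad_h / 2, pad_w / 2, pad_w - pad_w / 2};
}

int main() {
  // e.g. a 7x7 stride-2 stem convolution on a 224x224 input
  auto pads = SamePadList(224, 224, 7, 7, 2, 2, 1, 1);
  std::cout << pads[0] << " " << pads[1] << " " << pads[2] << " " << pads[3] << "\n";  // 2 3 2 3
  return 0;
}
```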
@@ -23,32 +23,8 @@

namespace mindspore {
namespace ops {
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
  for (auto item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto doutput = input_args[0];
  auto x_size = input_args[2];
  auto x_size_value = x_size->GetValueTrack();
  MS_EXCEPTION_IF_NULL(x_size);
  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
  // infer dtype
  auto dtype = doutput->BuildType();
  if (!dtype->isa<TensorType>()) {
    MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got " << dtype->ToString();
  }
  auto input_tensor_type = dtype->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(input_tensor_type);
  auto element = input_tensor_type->element();
  // infer shape
  auto dout_shape = doutput->BuildShape();
  MS_EXCEPTION_IF_NULL(doutput);
  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
  auto dout_shape_norm = dout_shapeptr->shape();
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
                const std::vector<int64_t> &x_size_v) {
  auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
  auto stride = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStride));
  auto dilation = GetValue<std::vector<int64_t>>(primitive->GetAttr(kDilation));
@@ -76,6 +52,34 @@ AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
    pad_list = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
  }
  primitive->AddAttr(kPadList, MakeValue(pad_list));
}
AbstractBasePtr Conv2DBackpropInputInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input number", input_args.size(), kEqual, 3, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto doutput = input_args[0];
  auto x_size = input_args[2];
  auto x_size_value = x_size->GetValueTrack();
  MS_EXCEPTION_IF_NULL(x_size);
  auto x_size_v = GetValue<std::vector<int64_t>>(x_size_value);
  // infer dtype
  auto dtype = doutput->BuildType();
  if (!dtype->isa<TensorType>()) {
    MS_LOG(EXCEPTION) << "Conv2DBackpropInputInfer doutput must be tensor but got " << dtype->ToString();
  }
  auto input_tensor_type = dtype->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(input_tensor_type);
  auto element = input_tensor_type->element();
  // infer shape
  auto dout_shape = doutput->BuildShape();
  MS_EXCEPTION_IF_NULL(doutput);
  auto dout_shapeptr = dout_shape->cast<abstract::ShapePtr>();
  auto dout_shape_norm = dout_shapeptr->shape();
  SetPadList(primitive, dout_shape_norm, x_size_v);
  return std::make_shared<abstract::AbstractTensor>(element, std::make_shared<abstract::Shape>(x_size_v));
}
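Conv2DBackpropInput receives the forward input size x_size explicitly, so shape inference just returns it; the extracted SetPadList only reconstructs the pads the forward pass used from x_size and the gradient shape. A standalone sketch of that recovery for SAME mode (standard conv arithmetic, not the MindSpore API; PadNeeded is an illustrative helper name):

```cpp
#include <cstdint>
#include <iostream>

// Total padding the forward conv must have used, recovered from the forward
// input extent x, the gradient extent dout, and the kernel geometry.
int64_t PadNeeded(int64_t x, int64_t dout, int64_t kernel, int64_t stride, int64_t dilation) {
  int64_t pad = (dout - 1) * stride + dilation * (kernel - 1) + 1 - x;
  return pad > 0 ? pad : 0;
}

int main() {
  // Forward: x_h = 28, kernel 3, stride 2, SAME -> dout_h = ceil(28 / 2) = 14.
  int64_t pad_h = PadNeeded(28, 14, 3, 2, 1);
  std::cout << "pad_top=" << pad_h / 2 << " pad_bottom=" << pad_h - pad_h / 2 << "\n";  // 0 and 1
  return 0;
}
```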
@@ -103,8 +103,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
  auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
  auto pad_mode_value = (primitive->GetAttr(kPadMode));
  PadMode pad_mode = PAD;
  pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
  PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
  auto batch = in_shape[0];
  auto channel = in_shape[1];
  auto in_h = in_shape[2];