Fix several PyNative dynamic-shape bugs on Ascend

hanhuifeng2020 2022-12-27 17:30:23 +08:00
parent f8bc59aedb
commit 5559fe3e1a
12 changed files with 96 additions and 40 deletions


@@ -17,6 +17,7 @@
#include <vector>
#include <map>
#include <set>
#include "runtime/rt.h"
#include "ir/tensor.h"
#include "include/common/utils/anfalgo.h"
@@ -68,11 +69,41 @@ int AclKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector
   return 0;
 }
 
+bool IsReduceOp(const std::string &op_type) {
+  const std::set<std::string> reduce_op_type = {prim::kPrimReduceAll->name(),  prim::kPrimReduceAny->name(),
+                                                prim::kPrimReduceMean->name(), prim::kPrimReduceMax->name(),
+                                                prim::kPrimReduceMin->name(),  prim::kPrimReduceProd->name(),
+                                                prim::kPrimReduceSum->name(),  prim::kPrimSquareSumV1->name()};
+  return reduce_op_type.count(op_type) != 0;
+}
+
+void AclKernelMod::UpdateReduceAxisAttr(const AnfNodePtr &node) {
+  if (!IsReduceOp(op_type_)) {
+    return;
+  }
+  if (!common::AnfAlgo::HasNodeAttr("axis", node->cast<CNodePtr>())) {
+    return;
+  }
+  ShapeVector axes = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "axis");
+  if (!axes.empty()) {
+    return;
+  }
+  // An empty "axis" attribute means "reduce all dimensions": expand it into
+  // the explicit list [0, rank) so the ACL operator receives concrete axes.
+  auto in_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  for (size_t i = 0; i < in_shape.size(); ++i) {
+    axes.push_back(SizeToLong(i));
+  }
+  common::AnfAlgo::SetNodeAttr("axis", MakeValue(axes), node);
+}
+
 void AclKernelMod::ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr) {
   auto node = anf_node_.lock();
   MS_EXCEPTION_IF_NULL(node);
   const auto &attr_to_input_maps = GeOpConvertor::GetNeedAddInput(node, true);
   const auto &input_names = kernel::AclUtils::GetOpInputAnchorNames(node);
+  UpdateReduceAxisAttr(node);
   auto attr_list = GeOpConvertor::GetAttrAndValue(node, true);
   for (auto &[attr_name, value] : attr_list) {
     if (value == nullptr) {

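Note: the empty-axis convention that UpdateReduceAxisAttr expands is the one MindSpore's Python API exposes; a minimal sketch for context (not part of this commit):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import operations as P

    x = ms.Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
    # With no axis argument, ReduceSum reduces every dimension -- the same
    # convention the kernel above expands into the explicit list [0, rank).
    print(P.ReduceSum()(x))     # scalar 15.0, reduced over axes (0, 1)
    print(P.ReduceSum()(x, 0))  # shape (3,), reduced over axis 0 only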

@@ -48,6 +48,7 @@ class AclKernelMod : public AscendKernelMod {
  protected:
   void SyncData() override;
   void ProcessAttribute(const std::shared_ptr<AclOpDesc> &op_desc_ptr);
+  void UpdateReduceAxisAttr(const AnfNodePtr &node);
 
  private:
   std::vector<GeTensorDescPtr> input_desc_list_{};


@@ -28,6 +28,7 @@
 #include "backend/common/session/anf_runtime_algorithm.h"
 #include "plugin/device/ascend/hal/device/ge_types_convert.h"
+#include "plugin/device/ascend/optimizer/ascend_helper.h"
 
 namespace mindspore {
 namespace kernel {
@@ -675,6 +676,11 @@ std::vector<GeTensorDescPtr> AclUtils::GetInputTensorDesc(const AnfNodePtr &anf_
     auto input_shape = AnfAlgo::GetOutputDeviceShape(input, idx);
     auto input_format = AnfAlgo::GetOutputFormat(input, idx);
     auto ori_format = IsOneOf3DFormat(input_format) ? kOpFormat_NCDHW : kOpFormat_DEFAULT;
+    if (!opt::NeedInsertTransData(ori_shape, input_format)) {
+      MS_LOG_DEBUG << "Set format of " << anf_node->fullname_with_scope() << " to origin format";
+      input_shape = ori_shape;
+      input_format = ori_format;
+    }
     ori_shape = UpdateShape(ori_shape, input_format, anf_node);
     auto input_desc = GeOpConvertor::GetTensorDesc(input_shape, input_type, input_format, ori_shape, ori_format);
     MS_EXCEPTION_IF_NULL(input_desc);
@@ -703,6 +709,11 @@ std::vector<GeTensorDescPtr> AclUtils::GetOutputTensorDesc(const AnfNodePtr &anf
     auto output_shape = AnfAlgo::GetOutputDeviceShape(anf_node, i);
     auto output_format = AnfAlgo::GetOutputFormat(anf_node, i);
     auto ori_format = IsOneOf3DFormat(output_format) ? kOpFormat_NCDHW : kOpFormat_DEFAULT;
+    if (!opt::NeedInsertTransData(ori_shape, output_format)) {
+      MS_LOG_DEBUG << "Set format of " << anf_node->fullname_with_scope() << " to origin format";
+      output_shape = ori_shape;
+      output_format = ori_format;
+    }
     ori_shape = UpdateShape(ori_shape, output_format, anf_node);
     auto output_desc = GeOpConvertor::GetTensorDesc(output_shape, output_type, output_format, ori_shape, ori_format);
     MS_EXCEPTION_IF_NULL(output_desc);


@@ -36,12 +36,6 @@ namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-bool NeedInsertTransData(const ShapeVector &origin_shape, const std::string &format) {
-  bool shape_check =
-    origin_shape.size() > 1 || (origin_shape.size() == 1 && origin_shape[0] % SizeToLong(kCubeSize) != 0);
-  return kCommonFormatSet.find(format) == kCommonFormatSet.end() && (shape_check || format == kOpFormat_ND_RNN_BIAS);
-}
-
 std::string GetTransOpName(const std::string &spec_format) {
   std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
                                ? prim::kPrimTransDataRNN->name()


@@ -120,6 +120,12 @@ void SetInputOutputNames(const std::vector<std::string> &input_names, const std:
 void SelectCallInlineKernelInfo(const CNodePtr &node);
 
 const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
+
+inline bool NeedInsertTransData(const ShapeVector &origin_shape, const std::string &format) {
+  bool shape_check =
+    origin_shape.size() > 1 || (origin_shape.size() == 1 && origin_shape[0] % SizeToLong(kCubeSize) != 0);
+  return kCommonFormatSet.find(format) == kCommonFormatSet.end() && (shape_check || format == kOpFormat_ND_RNN_BIAS);
+}
 }  // namespace opt
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
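Note: moving NeedInsertTransData into this header lets the ACL tensor-desc code above and the PyNative trans-data pass below share one rule. Its logic, restated as a hypothetical Python sketch (the format-name strings and kCubeSize == 16 are assumptions for illustration):

    COMMON_FORMATS = {"DefaultFormat", "ND", "NCHW", "NCDHW"}
    CUBE_SIZE = 16  # Ascend cube edge length assumed by the fractal formats

    def need_insert_trans_data(origin_shape, fmt):
        # Rank > 1 always needs conversion out of a private format; a rank-1
        # tensor only does when its length is not cube-aligned.
        shape_check = len(origin_shape) > 1 or (
            len(origin_shape) == 1 and origin_shape[0] % CUBE_SIZE != 0)
        return fmt not in COMMON_FORMATS and (shape_check or fmt == "ND_RNN_BIAS")

    assert need_insert_trans_data([2, 3], "FRACTAL_NZ")         # private format, rank 2
    assert not need_insert_trans_data([32], "FRACTAL_NZ")       # 16-aligned vector
    assert not need_insert_trans_data([2, 3], "DefaultFormat")  # already common

This rank-1 alignment check is why RunOpInsertTransData below now records the full input_shape_ instead of only its rank.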


@@ -98,7 +98,7 @@ bool RunOpInsertTransData::ConvertNodeFormat(const FuncGraphPtr &graph, const An
   MS_EXCEPTION_IF_NULL(cnode);
   bool changed = false;
   // convert the format of node to default
-  if (kCommonFormatSet.find(format) == kCommonFormatSet.end() && (input_size_ > 1 || format == kOpFormat_ND_RNN_BIAS)) {
+  if (NeedInsertTransData(input_shape_, format)) {
     auto input_node = (!is_insert) ? common::AnfAlgo::GetInputNode(cnode, input_index) : node;
     auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, insert_index, is_insert);
     common::AnfAlgo::SetNodeInput(cnode, trans_node, input_index);
@@ -121,7 +121,7 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
     for (size_t index = 0; index < input_num; ++index) {
       auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(node, index);
       auto prev_node_out_infer_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index);
-      input_size_ = prev_node_out_infer_shape.size();
+      input_shape_ = prev_node_out_infer_shape;
       auto input_format = AnfAlgo::GetInputFormat(node, index);
       // convert the format of node's input or output
       auto input_changed = ConvertNodeFormat(graph, node, prev_input_format, 0, index, false);


@@ -41,7 +41,7 @@ class RunOpInsertTransData : public Pass {
   bool ConvertNodeFormat(const FuncGraphPtr &graph, const AnfNodePtr &node, const std::string &format,
                          size_t insert_index, size_t input_index, bool is_insert) const;
   KernelSelectPtr kernel_select_;
-  size_t input_size_{0};
+  ShapeVector input_shape_;
 };
 }  // namespace opt
 }  // namespace mindspore


@@ -106,7 +106,7 @@ OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run
   std::vector<KernelWithIndex> outputs_with_index;
   for (auto &node : output_nodes) {
     MS_EXCEPTION_IF_NULL(node);
-    (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
+    (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernel(node, 0));
   }
   AnfAlgo::UpdateGraphValidRefPair(graph);


@@ -60,14 +60,13 @@ OUTPUT_MAP(ReduceAny) = {{0, OUTPUT_DESC(y)}};
 REG_ADPT_DESC(ReduceAny, kNameReduceAny, ADPT_DESC(ReduceAny))
 REG_ADPT_DESC(ReduceAnyD, kNameReduceAnyD, ADPT_DESC(ReduceAny))
 
-// ReduceSumD
-INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}};
-INPUT_ATTR_MAP(ReduceSumD) = {
-  {2, ATTR_DESC(axes, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
-ATTR_MAP(ReduceSumD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceSumD) = {{0, OUTPUT_DESC(y)}};
-REG_ADPT_DESC(ReduceSum, prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD))
-REG_ADPT_DESC(ReduceSumD, prim::kPrimReduceSumD->name(), ADPT_DESC(ReduceSumD))
+// ReduceSum
+INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
+ATTR_INPUT_MAP(ReduceSum) = {{"axis", "axes"}};
+ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}};
+REG_ADPT_DESC(ReduceSum, prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSum))
+REG_ADPT_DESC(ReduceSumD, prim::kPrimReduceSumD->name(), ADPT_DESC(ReduceSum))
 
 // ReduceProdD
 INPUT_MAP(ReduceProdD) = {{1, INPUT_DESC(x)}};

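Note: switching the adapter from ReduceSumD (axes as a compile-time attribute) to ReduceSum (axes as a real operand, with ATTR_INPUT_MAP bridging the old "axis" attribute) is what lets the axes stay a runtime value under dynamic shape. From Python, the non-D form looks like this (sketch using the standard MindSpore API):

    import numpy as np
    import mindspore as ms
    from mindspore.ops import operations as P

    x = ms.Tensor(np.ones((2, 3), np.float32))
    reduce_sum = P.ReduceSum(keep_dims=False)
    # The axis is an operand rather than an attribute, so one primitive can
    # serve inputs whose shape is only known at run time.
    print(reduce_sum(x, (0,)).shape)  # (3,)
    print(reduce_sum(x, (1,)).shape)  # (2,)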

@@ -58,9 +58,8 @@ DECLARE_OP_USE_OUTPUT(BNTrainingUpdate)
 DECLARE_OP_ADAPTER(BNTrainingUpdateGrad)
 DECLARE_OP_USE_OUTPUT(BNTrainingUpdateGrad)
 
-DECLARE_OP_ADAPTER(ReduceSumD)
-DECLARE_OP_USE_INPUT_ATTR(ReduceSumD)
-DECLARE_OP_USE_OUTPUT(ReduceSumD)
+DECLARE_OP_ADAPTER(ReduceSum)
+DECLARE_OP_USE_OUTPUT(ReduceSum)
 
 DECLARE_OP_ADAPTER(ReduceAny)
 DECLARE_OP_USE_OUTPUT(ReduceAny)


@@ -62,9 +62,9 @@ def test_net():
                       [0, -2, -4, -7],
                       [-3, -2, -3, -16]]]]).astype(np.float16))
     operator = Net()
-    output = operator(x, out)
     operator.set_inputs(Tensor(shape=[None, 1, 6, 6], dtype=mstype.float16),
                         Tensor(shape=[None, 1, 4, 4], dtype=mstype.float16))
+    output = operator(x, out)
     expect_out = np.array(
         [[[[-60., -142., -265.], [-104., -211., -322.], [-102., -144., -248.]]]]).astype(np.float16)
     assert np.allclose(output.asnumpy(), expect_out, 1e-3, 1e-3)

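Note: the test fix above moves the first execution after set_inputs, because set_inputs only marks dimensions as dynamic for runs that happen afterwards. The required ordering as a self-contained sketch (AddOne is a made-up cell for illustration):

    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor
    from mindspore.common import dtype as mstype

    class AddOne(nn.Cell):
        def construct(self, x):
            return x + 1

    net = AddOne()
    # Declare the dynamic dimension with None *before* the first call;
    # executing the cell first would bind a static shape instead.
    net.set_inputs(Tensor(shape=[None, 3], dtype=mstype.float32))
    out = net(Tensor(np.ones((2, 3), np.float32)))
    print(out.shape)  # (2, 3)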

@@ -21,22 +21,32 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
 
 
-class NetDyn(nn.Cell):
-    def __init__(self, reduction, indices):
-        super(NetDyn, self).__init__()
-        self.indices = indices
-        self.unique = P.Unique()
-        self.gather = P.Gather()
+class Net(nn.Cell):
+    def __init__(self, reduction):
+        super(Net, self).__init__()
         self.loss = P.BCEWithLogitsLoss(reduction=reduction)
 
     def construct(self, predict, target, weight, pos_weight):
-        unique_indice, _ = self.unique(self.indices)
-        predict = self.gather(predict, unique_indice, 0)
         return self.loss(predict, target, weight, pos_weight)
 
 
+def net_run():
+    predict = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
+    target = Tensor(np.arange(34, 40).reshape(2, 3).astype(np.float32))
+    weight = Tensor(np.array([2, 3, 1]).astype(np.float32))
+    pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
+    net = Net("mean")
+    net.set_inputs(Tensor(shape=[None, None], dtype=mstype.float32),
+                   Tensor(target), Tensor(weight), Tensor(pos_weight))
+    output = net(predict, target, weight, pos_weight)
+    expected = -113.55404
+    # assert scalar
+    assert math.isclose(output.asnumpy().tolist(), expected, rel_tol=1e-4, abs_tol=1e-4)
+
+
 @pytest.mark.level0
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
@@ -48,13 +58,18 @@ def test_bce_mean_dyn_ascend():
     Expectation: Assert that results are consistent with expect.
     """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    predict = Tensor(np.arange(6).reshape(2, 3).astype(np.float32))
-    target = Tensor(np.arange(34, 40).reshape(2, 3).astype(np.float32))
-    weight = Tensor(np.array([2, 3, 1]).astype(np.float32))
-    pos_weight = Tensor(np.array([6, 3, 4]).astype(np.float32))
-    indices = Tensor(np.array([0, 1]))
-    loss = NetDyn("mean", indices)
-    output = loss(predict, target, weight, pos_weight)
-    expected = -113.55404
-    # assert scalar
-    assert math.isclose(output.asnumpy().tolist(), expected, rel_tol=1e-4, abs_tol=1e-4)
+    net_run()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_bce_mean_dyn_ascend_pynative():
+    """
+    Feature: Test dynamic shape of BCEWithLogitsLoss op that the reduction is mean on ascend.
+    Description: The shape of input is dynamic.
+    Expectation: Assert that results are consistent with expect.
+    """
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    net_run()