!8920 Adapt ops LinSpace for Ascend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@linqingke,@liangchenghui
Signed-off-by: @liangchenghui,@liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-11-25 12:03:56 +08:00 committed by Gitee
commit 22d683a805
11 changed files with 220 additions and 51 deletions

View File

@ -51,6 +51,7 @@
#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h" #include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h"
#include "backend/optimizer/ascend/ir_fission/transdata_split.h" #include "backend/optimizer/ascend/ir_fission/transdata_split.h"
#include "backend/optimizer/ascend/ir_fission/topk_split.h" #include "backend/optimizer/ascend/ir_fission/topk_split.h"
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h" #include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h" #include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h" #include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
@ -165,6 +166,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>()); ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>()); ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>()); ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>()); ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>()); ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
@ -324,6 +326,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
ir_fusion_pm->AddPass(std::make_shared<BnSplit>()); ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>()); ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>()); ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());

View File

@ -0,0 +1,123 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
#include <vector>
#include <memory>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kLinSpaceInputNum = 3;
constexpr size_t kFloat32Len = 4;
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
// 1 get tensor value of input num
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_num = cnode->input(kLinSpaceInputNum);
MS_EXCEPTION_IF_NULL(input_num);
if (!IsValueNode<tensor::Tensor>(input_num)) {
return nullptr;
}
ValuePtr value = GetValueNode(input_num);
MS_EXCEPTION_IF_NULL(value);
auto tensor = value->cast<tensor::TensorPtr>();
MS_EXCEPTION_IF_NULL(tensor);
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(data);
// 2 create tensor
int64_t assist_num = *data;
std::vector<int64_t> assist_shape = {assist_num};
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat32);
MS_EXCEPTION_IF_NULL(tensor_type);
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), assist_shape);
MS_EXCEPTION_IF_NULL(assist_tensor);
assist_tensor->set_device_info(device_info);
// 3 set value of tensor
auto data_ptr = assist_tensor->data_c();
MS_EXCEPTION_IF_NULL(data_ptr);
std::vector<float> float_data;
size_t data_num = LongToSize(assist_num);
for (size_t i = 0; i < data_num; ++i) {
float_data.emplace_back(static_cast<float>(i));
}
auto elem_num = assist_num * kFloat32Len;
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(assist_tensor->data().nbytes()), float_data.data(), elem_num);
if (ret_code != 0) {
MS_LOG(ERROR) << "Failed to copy data into Tensor.";
return nullptr;
}
return assist_tensor;
}
ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
tensor::TensorPtr assist_tensor = CreateTensor(node);
MS_EXCEPTION_IF_NULL(assist_tensor);
auto assist_const = std::make_shared<ValueNode>(assist_tensor);
MS_EXCEPTION_IF_NULL(assist_const);
auto assist_abstract = assist_tensor->ToAbstract();
assist_const->set_abstract(assist_abstract);
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
MS_EXCEPTION_IF_NULL(assist_kernel_info);
assist_const->set_kernel_info(assist_kernel_info);
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
op_builder.SetOutputsFormat({kOpFormat_DEFAULT});
op_builder.SetOutputsDeviceType({kNumberTypeFloat32});
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
return assist_const;
}
} // namespace
const BaseRef LinSpaceFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto lin_space_prim = std::make_shared<Primitive>(kLinSpaceOpName);
return VectorRef({lin_space_prim, Xs});
}
const AnfNodePtr LinSpaceFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto kernel_graph = graph->cast<KernelGraphPtr>();
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != kLinSpaceInputNum + 1) {
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kLinSpaceInputNum << " inputs";
return nullptr;
}
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kLinSpaceOpName))};
auto assist_const = CreateValueNode(cnode);
new_inputs.push_back(assist_const);
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(cnode->abstract());
new_cnode->set_scope(cnode->scope());
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_const);
MS_LOG(INFO) << "Split linspace op success.";
}
return new_cnode;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class LinSpaceFission : public PatternProcessPass {
public:
explicit LinSpaceFission(bool multigraph = true) : PatternProcessPass("lin_space_fission", multigraph) {}
~LinSpaceFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_

View File

@ -61,6 +61,7 @@ constexpr auto kReduceScatterOpName = "ReduceScatter";
constexpr auto kHostReduceScatterOpName = "HostReduceScatter"; constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
constexpr auto kMemCpyAsyncOpName = "memcpy_async"; constexpr auto kMemCpyAsyncOpName = "memcpy_async";
constexpr auto kTopKOpName = "TopK"; constexpr auto kTopKOpName = "TopK";
constexpr auto kLinSpaceOpName = "LinSpace";
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches"; constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce"; constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate"; constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";

View File

@ -25,6 +25,7 @@ from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore import nn from mindspore import nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F
__all__ = ['LSTM', 'LSTMCell'] __all__ = ['LSTM', 'LSTMCell']
@ -36,6 +37,10 @@ def _create_sequence_length(shape):
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32) sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
return sequence_length return sequence_length
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class LSTM(Cell): class LSTM(Cell):
r""" r"""
LSTM (Long Short-Term Memory) layer. LSTM (Long Short-Term Memory) layer.
@ -230,6 +235,9 @@ class LSTM(Cell):
x = self.transpose(x, (1, 0, 2)) x = self.transpose(x, (1, 0, 2))
h, c = hx h, c = hx
if self.is_ascend: if self.is_ascend:
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
x = self.cast(x, mstype.float16) x = self.cast(x, mstype.float16)
h = self.cast(h, mstype.float16) h = self.cast(h, mstype.float16)
c = self.cast(c, mstype.float16) c = self.cast(c, mstype.float16)

View File

@ -176,18 +176,17 @@ class LinSpace(Cell):
validator.check_positive_int(num, "num", self.cls_name) validator.check_positive_int(num, "num", self.cls_name)
self.is_single = bool(num == 1) self.is_single = bool(num == 1)
self.lin_space = inner.LinSpace() self.lin_space = P.LinSpace()
self.start = Tensor(start, mstype.float32) self.start = Tensor(start, mstype.float32)
self.stop = Tensor(stop, mstype.float32) self.stop = Tensor(stop, mstype.float32)
self.assist = Tensor(list(range(num)), mstype.float32) self.num = num
self.num = Tensor(num, mstype.int32)
self.start_array = Tensor([start], mstype.float32) self.start_array = Tensor([start], mstype.float32)
def construct(self): def construct(self):
if self.is_single: if self.is_single:
return self.start_array return self.start_array
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num) lin_space_out = self.lin_space(self.start, self.stop, self.num)
return lin_space_out return lin_space_out
@constexpr @constexpr

View File

@ -22,7 +22,6 @@ from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F from .. import functional as F
from .. import operations as P from .. import operations as P
from ..operations import _grad_ops as G from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
from .grad_base import bprop_getters from .grad_base import bprop_getters
@ -1188,7 +1187,7 @@ def get_bprop_inv(self):
return bprop return bprop
@bprop_getters.register(inner.LinSpace) @bprop_getters.register(P.LinSpace)
def get_bprop_lin_space(self): def get_bprop_lin_space(self):
"""Grad definition for `LinSpace` operation.""" """Grad definition for `LinSpace` operation."""

View File

@ -50,7 +50,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod, Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, MatMul, Maximum, LogicalNot, LogicalOr, MatMul, Maximum,
Minimum, Mul, Neg, NMSWithMask, NotEqual, Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus, NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy, Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
@ -130,6 +130,7 @@ __all__ = [
'BatchNorm', 'BatchNorm',
'MaxPool', 'MaxPool',
'TopK', 'TopK',
'LinSpace',
'Adam', 'Adam',
'FusedSparseAdam', 'FusedSparseAdam',
'FusedSparseLazyAdam', 'FusedSparseLazyAdam',

View File

@ -274,45 +274,6 @@ class Dequant(PrimitiveWithInfer):
return mstype.float16 return mstype.float16
class LinSpace(PrimitiveWithInfer):
r"""
Generates values in an interval. And return the corresponding interpolation accroding to assist.
Inputs:
- **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
- **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
With shape of 0-D.
Outputs:
Tensor, has the same shape as `assist`.
Examples:
>>> linspace = ops.LinSpace()
>>> assist = Tensor([5, 5.5], mindspore.float32)
>>> start = Tensor(1, mindspore.float32)
>>> stop = Tensor(10, mindspore.float32)
>>> num = Tensor(5, mindspore.int32)
>>> output = linspace(assist, start, stop, num)
>>> print(output)
[12.25, 13.375]
"""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, assist, start, stop, num):
return assist
def infer_dtype(self, assist, start, stop, num):
validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
args = {"assist": assist, "start": start, "stop": stop}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
return assist
class MatrixDiag(PrimitiveWithInfer): class MatrixDiag(PrimitiveWithInfer):
""" """
Returns a batched diagonal tensor with a given batched diagonal values. Returns a batched diagonal tensor with a given batched diagonal values.

View File

@ -3942,3 +3942,46 @@ class Eps(PrimitiveWithInfer):
'dtype': input_x['dtype'], 'dtype': input_x['dtype'],
} }
return out return out
class LinSpace(PrimitiveWithInfer):
r"""
Generates values in an interval and returns the corresponding interpolation accroding to assist.
Inputs:
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
- **num** (int) - Ticks number in the interval, the ticks include start and stop value.
Outputs:
Tensor, has the same shape as `assist`.
Examples:
>>> linspace = P.LinSpace()
>>> start = Tensor(1, mindspore.float32)
>>> stop = Tensor(10, mindspore.float32)
>>> num = 5
>>> output = linspace(start, stop, num)
>>> print(output)
[ 1. 3.25 5.5 7.75 10. ]
"""
@prim_attr_register
def __init__(self):
"""Initialize LinSpace"""
def __infer__(self, start, stop, num):
args = {"start": start['dtype'], "stop": start['dtype']}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
start_shape = start['shape']
stop_shape = stop['shape']
validator.check_equal_int(len(start_shape), 0, "rank of start_shape", self.name)
validator.check_equal_int(len(stop_shape), 0, "rank of stop_shape", self.name)
num_v = num['value']
validator.check_value_type('num', num_v, [int], self.name)
validator.check_positive_int(num_v, "num", self.name)
out_shape = [num_v]
out = {'shape': out_shape,
'dtype': start['dtype'],
'value': None}
return out

View File

@ -2202,11 +2202,10 @@ test_case_array_ops = [
'skip': ['backward'], 'skip': ['backward'],
}), }),
('LinSpace', { ('LinSpace', {
'block': inner.LinSpace(), 'block': P.LinSpace(),
'desc_inputs': [Tensor([5, 5.5], mstype.float32), 'desc_const': [5],
Tensor(1, mstype.float32), 'desc_inputs': [Tensor(1, mstype.float32),
Tensor(10, mstype.float32), Tensor(10, mstype.float32)],
Tensor(5, mstype.int32)],
'skip': ['backward'], 'skip': ['backward'],
}), }),
('MatrixDiag', { ('MatrixDiag', {