forked from mindspore-Ecosystem/mindspore
!8920 Adapt ops LinSpace for Ascend.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@linqingke,@liangchenghui Signed-off-by: @liangchenghui,@liangchenghui
This commit is contained in:
commit
22d683a805
|
@ -51,6 +51,7 @@
|
||||||
#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
|
#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
|
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
|
||||||
|
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
|
||||||
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
|
||||||
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
|
||||||
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
|
||||||
|
@ -165,6 +166,7 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
||||||
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
|
ir_fusion_pm->AddPass(std::make_shared<ReshapeTransposeFusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
|
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
|
ir_fusion_pm->AddPass(std::make_shared<MomentumLossscaleFusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
|
ir_fusion_pm->AddPass(std::make_shared<MulAddFusion>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
|
ir_fusion_pm->AddPass(std::make_shared<MulAddNFusion>());
|
||||||
|
@ -324,6 +326,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
||||||
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||||
|
|
|
@ -0,0 +1,123 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "frontend/optimizer/opt.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kLinSpaceInputNum = 3;
|
||||||
|
constexpr size_t kFloat32Len = 4;
|
||||||
|
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||||
|
// 1 get tensor value of input num
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
auto input_num = cnode->input(kLinSpaceInputNum);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_num);
|
||||||
|
if (!IsValueNode<tensor::Tensor>(input_num)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ValuePtr value = GetValueNode(input_num);
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
|
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
|
||||||
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
|
||||||
|
// 2 create tensor
|
||||||
|
int64_t assist_num = *data;
|
||||||
|
std::vector<int64_t> assist_shape = {assist_num};
|
||||||
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat32);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
|
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
||||||
|
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), assist_shape);
|
||||||
|
MS_EXCEPTION_IF_NULL(assist_tensor);
|
||||||
|
assist_tensor->set_device_info(device_info);
|
||||||
|
|
||||||
|
// 3 set value of tensor
|
||||||
|
auto data_ptr = assist_tensor->data_c();
|
||||||
|
MS_EXCEPTION_IF_NULL(data_ptr);
|
||||||
|
std::vector<float> float_data;
|
||||||
|
size_t data_num = LongToSize(assist_num);
|
||||||
|
for (size_t i = 0; i < data_num; ++i) {
|
||||||
|
float_data.emplace_back(static_cast<float>(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto elem_num = assist_num * kFloat32Len;
|
||||||
|
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(assist_tensor->data().nbytes()), float_data.data(), elem_num);
|
||||||
|
if (ret_code != 0) {
|
||||||
|
MS_LOG(ERROR) << "Failed to copy data into Tensor.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return assist_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
||||||
|
tensor::TensorPtr assist_tensor = CreateTensor(node);
|
||||||
|
MS_EXCEPTION_IF_NULL(assist_tensor);
|
||||||
|
auto assist_const = std::make_shared<ValueNode>(assist_tensor);
|
||||||
|
MS_EXCEPTION_IF_NULL(assist_const);
|
||||||
|
auto assist_abstract = assist_tensor->ToAbstract();
|
||||||
|
assist_const->set_abstract(assist_abstract);
|
||||||
|
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
|
||||||
|
MS_EXCEPTION_IF_NULL(assist_kernel_info);
|
||||||
|
assist_const->set_kernel_info(assist_kernel_info);
|
||||||
|
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
|
||||||
|
op_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||||
|
op_builder.SetOutputsDeviceType({kNumberTypeFloat32});
|
||||||
|
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
|
||||||
|
return assist_const;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef LinSpaceFission::DefinePattern() const {
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
auto lin_space_prim = std::make_shared<Primitive>(kLinSpaceOpName);
|
||||||
|
return VectorRef({lin_space_prim, Xs});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr LinSpaceFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (cnode->size() != kLinSpaceInputNum + 1) {
|
||||||
|
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kLinSpaceInputNum << " inputs";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kLinSpaceOpName))};
|
||||||
|
auto assist_const = CreateValueNode(cnode);
|
||||||
|
new_inputs.push_back(assist_const);
|
||||||
|
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||||
|
CNodePtr new_cnode = graph->NewCNode(new_inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||||
|
new_cnode->set_abstract(cnode->abstract());
|
||||||
|
new_cnode->set_scope(cnode->scope());
|
||||||
|
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||||
|
if (kernel_graph != nullptr) {
|
||||||
|
kernel_graph->AddValueNodeToGraph(assist_const);
|
||||||
|
MS_LOG(INFO) << "Split linspace op success.";
|
||||||
|
}
|
||||||
|
return new_cnode;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,32 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
||||||
|
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class LinSpaceFission : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit LinSpaceFission(bool multigraph = true) : PatternProcessPass("lin_space_fission", multigraph) {}
|
||||||
|
~LinSpaceFission() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
|
@ -61,6 +61,7 @@ constexpr auto kReduceScatterOpName = "ReduceScatter";
|
||||||
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
||||||
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
||||||
constexpr auto kTopKOpName = "TopK";
|
constexpr auto kTopKOpName = "TopK";
|
||||||
|
constexpr auto kLinSpaceOpName = "LinSpace";
|
||||||
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
|
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
|
||||||
constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
|
constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
|
||||||
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
|
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
|
||||||
|
|
|
@ -25,6 +25,7 @@ from mindspore.common.tensor import Tensor
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
from mindspore import nn
|
from mindspore import nn
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.ops import functional as F
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['LSTM', 'LSTMCell']
|
__all__ = ['LSTM', 'LSTMCell']
|
||||||
|
@ -36,6 +37,10 @@ def _create_sequence_length(shape):
|
||||||
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
|
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
|
||||||
return sequence_length
|
return sequence_length
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||||
|
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||||
|
|
||||||
class LSTM(Cell):
|
class LSTM(Cell):
|
||||||
r"""
|
r"""
|
||||||
LSTM (Long Short-Term Memory) layer.
|
LSTM (Long Short-Term Memory) layer.
|
||||||
|
@ -230,6 +235,9 @@ class LSTM(Cell):
|
||||||
x = self.transpose(x, (1, 0, 2))
|
x = self.transpose(x, (1, 0, 2))
|
||||||
h, c = hx
|
h, c = hx
|
||||||
if self.is_ascend:
|
if self.is_ascend:
|
||||||
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
|
_check_input_dtype(F.dtype(h), "h", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
|
_check_input_dtype(F.dtype(c), "c", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
x = self.cast(x, mstype.float16)
|
x = self.cast(x, mstype.float16)
|
||||||
h = self.cast(h, mstype.float16)
|
h = self.cast(h, mstype.float16)
|
||||||
c = self.cast(c, mstype.float16)
|
c = self.cast(c, mstype.float16)
|
||||||
|
|
|
@ -176,18 +176,17 @@ class LinSpace(Cell):
|
||||||
validator.check_positive_int(num, "num", self.cls_name)
|
validator.check_positive_int(num, "num", self.cls_name)
|
||||||
|
|
||||||
self.is_single = bool(num == 1)
|
self.is_single = bool(num == 1)
|
||||||
self.lin_space = inner.LinSpace()
|
self.lin_space = P.LinSpace()
|
||||||
self.start = Tensor(start, mstype.float32)
|
self.start = Tensor(start, mstype.float32)
|
||||||
self.stop = Tensor(stop, mstype.float32)
|
self.stop = Tensor(stop, mstype.float32)
|
||||||
self.assist = Tensor(list(range(num)), mstype.float32)
|
self.num = num
|
||||||
self.num = Tensor(num, mstype.int32)
|
|
||||||
self.start_array = Tensor([start], mstype.float32)
|
self.start_array = Tensor([start], mstype.float32)
|
||||||
|
|
||||||
def construct(self):
|
def construct(self):
|
||||||
if self.is_single:
|
if self.is_single:
|
||||||
return self.start_array
|
return self.start_array
|
||||||
|
|
||||||
lin_space_out = self.lin_space(self.assist, self.start, self.stop, self.num)
|
lin_space_out = self.lin_space(self.start, self.stop, self.num)
|
||||||
return lin_space_out
|
return lin_space_out
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
|
|
@ -22,7 +22,6 @@ from mindspore.ops import _selected_grad_ops as SG
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from ..operations import _grad_ops as G
|
from ..operations import _grad_ops as G
|
||||||
from ..operations import _inner_ops as inner
|
|
||||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
from ..functional import broadcast_gradient_args, reduced_shape, tuple_div
|
||||||
from .grad_base import bprop_getters
|
from .grad_base import bprop_getters
|
||||||
|
@ -1188,7 +1187,7 @@ def get_bprop_inv(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(inner.LinSpace)
|
@bprop_getters.register(P.LinSpace)
|
||||||
def get_bprop_lin_space(self):
|
def get_bprop_lin_space(self):
|
||||||
"""Grad definition for `LinSpace` operation."""
|
"""Grad definition for `LinSpace` operation."""
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
|
||||||
LogicalNot, LogicalOr, MatMul, Maximum,
|
LogicalNot, LogicalOr, MatMul, Maximum,
|
||||||
Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
Minimum, Mul, Neg, NMSWithMask, NotEqual,
|
||||||
NPUAllocFloatStatus, NPUClearFloatStatus,
|
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace,
|
||||||
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
|
||||||
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
Reciprocal, CumSum, HistogramFixedWidth, SquaredDifference, Xdivy, Xlogy,
|
||||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
||||||
|
@ -130,6 +130,7 @@ __all__ = [
|
||||||
'BatchNorm',
|
'BatchNorm',
|
||||||
'MaxPool',
|
'MaxPool',
|
||||||
'TopK',
|
'TopK',
|
||||||
|
'LinSpace',
|
||||||
'Adam',
|
'Adam',
|
||||||
'FusedSparseAdam',
|
'FusedSparseAdam',
|
||||||
'FusedSparseLazyAdam',
|
'FusedSparseLazyAdam',
|
||||||
|
|
|
@ -274,45 +274,6 @@ class Dequant(PrimitiveWithInfer):
|
||||||
return mstype.float16
|
return mstype.float16
|
||||||
|
|
||||||
|
|
||||||
class LinSpace(PrimitiveWithInfer):
|
|
||||||
r"""
|
|
||||||
Generates values in an interval. And return the corresponding interpolation accroding to assist.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
- **assist** (Tensor[float32]) - The assist value, With shape of 0-D or 1-D.
|
|
||||||
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
|
|
||||||
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
|
|
||||||
- **num** (Tensor[int32]) - ticks number in the interval, the ticks include start and stop value.
|
|
||||||
With shape of 0-D.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, has the same shape as `assist`.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> linspace = ops.LinSpace()
|
|
||||||
>>> assist = Tensor([5, 5.5], mindspore.float32)
|
|
||||||
>>> start = Tensor(1, mindspore.float32)
|
|
||||||
>>> stop = Tensor(10, mindspore.float32)
|
|
||||||
>>> num = Tensor(5, mindspore.int32)
|
|
||||||
>>> output = linspace(assist, start, stop, num)
|
|
||||||
>>> print(output)
|
|
||||||
[12.25, 13.375]
|
|
||||||
"""
|
|
||||||
|
|
||||||
@prim_attr_register
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def infer_shape(self, assist, start, stop, num):
|
|
||||||
return assist
|
|
||||||
|
|
||||||
def infer_dtype(self, assist, start, stop, num):
|
|
||||||
validator.check_tensor_dtype_valid("num", num, (mstype.int32,), self.name)
|
|
||||||
args = {"assist": assist, "start": start, "stop": stop}
|
|
||||||
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
|
|
||||||
return assist
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixDiag(PrimitiveWithInfer):
|
class MatrixDiag(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Returns a batched diagonal tensor with a given batched diagonal values.
|
Returns a batched diagonal tensor with a given batched diagonal values.
|
||||||
|
|
|
@ -3942,3 +3942,46 @@ class Eps(PrimitiveWithInfer):
|
||||||
'dtype': input_x['dtype'],
|
'dtype': input_x['dtype'],
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class LinSpace(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Generates values in an interval and returns the corresponding interpolation accroding to assist.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **start** (Tensor[float32]) - The start of interval, With shape of 0-D.
|
||||||
|
- **stop** (Tensor[float32]) - The end of interval, With shape of 0-D.
|
||||||
|
- **num** (int) - Ticks number in the interval, the ticks include start and stop value.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the same shape as `assist`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> linspace = P.LinSpace()
|
||||||
|
>>> start = Tensor(1, mindspore.float32)
|
||||||
|
>>> stop = Tensor(10, mindspore.float32)
|
||||||
|
>>> num = 5
|
||||||
|
>>> output = linspace(start, stop, num)
|
||||||
|
>>> print(output)
|
||||||
|
[ 1. 3.25 5.5 7.75 10. ]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize LinSpace"""
|
||||||
|
|
||||||
|
def __infer__(self, start, stop, num):
|
||||||
|
args = {"start": start['dtype'], "stop": start['dtype']}
|
||||||
|
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float32,), self.name)
|
||||||
|
start_shape = start['shape']
|
||||||
|
stop_shape = stop['shape']
|
||||||
|
validator.check_equal_int(len(start_shape), 0, "rank of start_shape", self.name)
|
||||||
|
validator.check_equal_int(len(stop_shape), 0, "rank of stop_shape", self.name)
|
||||||
|
num_v = num['value']
|
||||||
|
validator.check_value_type('num', num_v, [int], self.name)
|
||||||
|
validator.check_positive_int(num_v, "num", self.name)
|
||||||
|
out_shape = [num_v]
|
||||||
|
out = {'shape': out_shape,
|
||||||
|
'dtype': start['dtype'],
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
|
@ -2202,11 +2202,10 @@ test_case_array_ops = [
|
||||||
'skip': ['backward'],
|
'skip': ['backward'],
|
||||||
}),
|
}),
|
||||||
('LinSpace', {
|
('LinSpace', {
|
||||||
'block': inner.LinSpace(),
|
'block': P.LinSpace(),
|
||||||
'desc_inputs': [Tensor([5, 5.5], mstype.float32),
|
'desc_const': [5],
|
||||||
Tensor(1, mstype.float32),
|
'desc_inputs': [Tensor(1, mstype.float32),
|
||||||
Tensor(10, mstype.float32),
|
Tensor(10, mstype.float32)],
|
||||||
Tensor(5, mstype.int32)],
|
|
||||||
'skip': ['backward'],
|
'skip': ['backward'],
|
||||||
}),
|
}),
|
||||||
('MatrixDiag', {
|
('MatrixDiag', {
|
||||||
|
|
Loading…
Reference in New Issue