forked from mindspore-Ecosystem/mindspore
!37722 [feat] [assistant] [I56J5P] aicpu operator linspace
Merge pull request !37722 from zhixinaa/linspace
This commit is contained in:
commit
a8c9c1b0e9
|
@ -167,9 +167,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
|
||||
addn_check_dump_ = MakeSubstitution(std::make_shared<AddNCheckDump>(), "addn_check_dump", prim::kPrimAddN);
|
||||
|
||||
// linspace
|
||||
lin_space_val_ = MakeSubstitution(std::make_shared<LinSpaceValue>(), "lin_space_val", prim::kPrimLinSpace);
|
||||
|
||||
// AccumulateNV2
|
||||
accumulaten_eliminater_ =
|
||||
MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
|
||||
|
|
|
@ -90,9 +90,6 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr addn_zero_filter_;
|
||||
SubstitutionPtr addn_check_dump_;
|
||||
|
||||
// linspace
|
||||
SubstitutionPtr lin_space_val_;
|
||||
|
||||
// AccumulateNV2
|
||||
SubstitutionPtr accumulaten_eliminater_;
|
||||
|
||||
|
|
|
@ -137,44 +137,6 @@ class ParallelVirtualNodeEliminater : public AnfVisitor {
|
|||
}
|
||||
};
|
||||
|
||||
class LinSpaceValue : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto func_graph = cnode->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto &inputs = cnode->inputs();
|
||||
constexpr size_t kInputSize = 4;
|
||||
if (inputs.size() != kInputSize) {
|
||||
return nullptr;
|
||||
}
|
||||
for (auto const item : inputs) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_val_ptr = inputs[kIndex3]->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_val_ptr);
|
||||
if (input_val_ptr->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (GetValue<int64_t>(input_val_ptr->value()) != 1) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto &reshape_fn = python_adapter::GetPyFn("mindspore.ops.functional", "reshape");
|
||||
|
||||
auto reshape_fg = parse::ParsePythonCode(reshape_fn);
|
||||
auto res = std::make_shared<pipeline::Resource>();
|
||||
(void)parse::ResolveFuncGraph(reshape_fg, res);
|
||||
|
||||
auto shape = NewValueNode(MakeValue(std::vector<int64_t>{1}));
|
||||
AnfNodePtr reshape_node = func_graph->NewCNodeInOrder({NewValueNode(reshape_fg), inputs[1], shape});
|
||||
AnfNodePtr stop_grad_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimStopGradient), reshape_node});
|
||||
return stop_grad_node;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimSameTypeShape, X, Y} -> X
|
||||
class SameEliminater : public AnfVisitor {
|
||||
public:
|
||||
|
|
|
@ -319,7 +319,6 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.tile_eliminate_,
|
||||
irpass.transpose_eliminate_,
|
||||
irpass.minmaximum_grad_,
|
||||
irpass.lin_space_val_,
|
||||
|
||||
// Arithmetic simplifications
|
||||
irpass.arithmetic_simplify_,
|
||||
|
|
|
@ -82,6 +82,7 @@ constexpr auto kMaskedSelect = "MaskedSelect";
|
|||
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
|
||||
constexpr auto kDynamicStitch = "DynamicStitch";
|
||||
constexpr auto kSearchSorted = "SearchSorted";
|
||||
constexpr auto kLinSpace = "LinSpace";
|
||||
constexpr auto kResizeBilinear = "ResizeBilinear";
|
||||
constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad";
|
||||
constexpr auto kTensorScatterElements = "TensorScatterElements";
|
||||
|
@ -119,9 +120,9 @@ constexpr auto kRandomShuffle = "RandomShuffle";
|
|||
constexpr auto kHSigmoid = "HSigmoid";
|
||||
constexpr auto kHSigmoidGrad = "HSigmoidGrad";
|
||||
|
||||
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad,
|
||||
kDynamicStitch, kSearchSorted, kResizeBilinear,
|
||||
kResizeBilinearGrad, kTensorScatterElements, kUniqueConsecutive};
|
||||
const std::set<std::string> kCpuKernelOps{
|
||||
kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch, kSearchSorted,
|
||||
kResizeBilinear, kResizeBilinearGrad, kTensorScatterElements, kUniqueConsecutive, kLinSpace};
|
||||
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift, kDropout3D,
|
||||
kDropout2D, kNonMaxSuppressionV3, kGetNext, kInitData, kPrint};
|
||||
const std::set<std::string> kCpuKernelBaseOps{kRandomChoiceWithMask,
|
||||
|
|
|
@ -64,7 +64,6 @@
|
|||
#include "plugin/device/ascend/optimizer/ir_fission/transdata_split.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/topk_split.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/conv2d_backprop_filter_mul_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/lin_space_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/space_to_depth_split.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/diag_fission.h"
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/diag_part_fission.h"
|
||||
|
@ -230,7 +229,6 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<Conv2dBackpropFilterMul>());
|
||||
// ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DeformableOffsetsFusion>());
|
||||
|
@ -439,7 +437,6 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
ir_fusion_pm->AddPass(std::make_shared<Conv2dBackpropInputDilationFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<Conv2dBackpropFilterMul>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
|
||||
|
|
|
@ -1,126 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/optimizer/ir_fission/lin_space_fission.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kLinSpaceInputNum = 3;
|
||||
constexpr size_t kFloat32Len = 4;
|
||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||
// 1 get tensor value of input num
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_num = cnode->input(kLinSpaceInputNum);
|
||||
MS_EXCEPTION_IF_NULL(input_num);
|
||||
if (!IsValueNode<tensor::Tensor>(input_num)) {
|
||||
return nullptr;
|
||||
}
|
||||
ValuePtr value = GetValueNode(input_num);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
int32_t *data = static_cast<int32_t *>(tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
|
||||
// 2 create tensor
|
||||
int64_t assist_num = *data;
|
||||
std::vector<int64_t> assist_shape = {assist_num};
|
||||
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat32);
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
|
||||
tensor::TensorPtr assist_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), assist_shape);
|
||||
MS_EXCEPTION_IF_NULL(assist_tensor);
|
||||
assist_tensor->set_device_info(device_info);
|
||||
|
||||
// 3 set value of tensor
|
||||
auto data_ptr = assist_tensor->data_c();
|
||||
MS_EXCEPTION_IF_NULL(data_ptr);
|
||||
std::vector<float> float_data;
|
||||
size_t data_num = LongToSize(assist_num);
|
||||
for (size_t i = 0; i < data_num; ++i) {
|
||||
float_data.emplace_back(static_cast<float>(i));
|
||||
}
|
||||
|
||||
auto elem_num = data_num * kFloat32Len;
|
||||
auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(assist_tensor->data().nbytes()), float_data.data(), elem_num);
|
||||
if (ret_code != 0) {
|
||||
MS_LOG(ERROR) << "Failed to copy data into Tensor while creating assist input for LinSpace op, memcpy_s errorno: "
|
||||
<< ret_code;
|
||||
return nullptr;
|
||||
}
|
||||
return assist_tensor;
|
||||
}
|
||||
|
||||
ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
||||
tensor::TensorPtr assist_tensor = CreateTensor(node);
|
||||
MS_EXCEPTION_IF_NULL(assist_tensor);
|
||||
auto assist_const = std::make_shared<ValueNode>(assist_tensor);
|
||||
MS_EXCEPTION_IF_NULL(assist_const);
|
||||
auto assist_abstract = assist_tensor->ToAbstract();
|
||||
assist_const->set_abstract(assist_abstract);
|
||||
auto assist_kernel_info = std::make_shared<device::KernelInfo>();
|
||||
MS_EXCEPTION_IF_NULL(assist_kernel_info);
|
||||
assist_const->set_kernel_info(assist_kernel_info);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder op_builder;
|
||||
op_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
op_builder.SetOutputsDeviceType({kNumberTypeFloat32});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(op_builder.Build(), assist_const.get());
|
||||
return assist_const;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef LinSpaceFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto lin_space_prim = std::make_shared<Primitive>(kLinSpaceOpName);
|
||||
return VectorRef({lin_space_prim, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr LinSpaceFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() != kLinSpaceInputNum + 1) {
|
||||
MS_LOG(INFO) << "The node " << cnode->DebugString() << " is not equal to " << kLinSpaceInputNum << " inputs";
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kLinSpaceOpName))};
|
||||
auto assist_const = CreateValueNode(cnode);
|
||||
new_inputs.push_back(assist_const);
|
||||
(void)new_inputs.insert(new_inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
CNodePtr new_cnode = NewCNode(new_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
common::AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->AddValueNodeToGraph(assist_const);
|
||||
MS_LOG(INFO) << "Split linspace op success.";
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,32 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
||||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class LinSpaceFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit LinSpaceFission(bool multigraph = true) : PatternProcessPass("lin_space_fission", multigraph) {}
|
||||
~LinSpaceFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_LIN_SPACE_FUSION_H_
|
|
@ -30,6 +30,7 @@ using KernelRunFunc = LinSpaceCpuKernelMod::KernelRunFunc;
|
|||
bool LinSpaceCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
num_dtype_ = inputs[kIndex2]->GetDtype();
|
||||
|
||||
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
|
||||
return false;
|
||||
|
@ -52,7 +53,6 @@ int LinSpaceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
|
|||
multi_dims_ = (batch_num_ != 1);
|
||||
|
||||
const auto dtype_size = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
|
||||
|
||||
// Deal with workspace_size_list_
|
||||
workspace_size_list_.clear();
|
||||
workspace_size_list_ = {LongToSize(batch_num_) * dtype_size};
|
||||
|
@ -64,7 +64,13 @@ template <typename T>
|
|||
bool LinSpaceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
const int64_t num = *reinterpret_cast<int64_t *>(inputs[kIndex2]->addr);
|
||||
int64_t num;
|
||||
if (num_dtype_ == kNumberTypeInt32) {
|
||||
int32_t num_val = *reinterpret_cast<int32_t *>(inputs[kIndex2]->addr);
|
||||
num = IntToLong(num_val);
|
||||
} else {
|
||||
num = *reinterpret_cast<int64_t *>(inputs[kIndex2]->addr);
|
||||
}
|
||||
// Deal wtih num equal to 1
|
||||
if (num == 1) {
|
||||
const auto input = inputs[kIndex0];
|
||||
|
@ -143,7 +149,18 @@ const std::vector<std::pair<KernelAttr, KernelRunFunc>> &LinSpaceCpuKernelMod::G
|
|||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&LinSpaceCpuKernelMod::LaunchKernel<double>},
|
||||
};
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&LinSpaceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&LinSpaceCpuKernelMod::LaunchKernel<double>}};
|
||||
return func_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LinSpace, LinSpaceCpuKernelMod);
|
||||
|
|
|
@ -55,6 +55,7 @@ class LinSpaceCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
|
|||
const std::vector<AddressPtr> &outputs);
|
||||
int64_t batch_num_{0};
|
||||
bool multi_dims_{false};
|
||||
TypeId num_dtype_{kTypeUnknown};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,18 +27,11 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
#define IsNoneOrAnyValue(value_ptr) ((value_ptr->isa<None>()) || (value_ptr->isa<AnyValue>()))
|
||||
TypePtr LinSpaceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
||||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractScalar>(prim_name, input_args, kInputIndex2);
|
||||
|
||||
auto num_value = input_args[kInputIndex2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(num_value);
|
||||
if (!num_value->isa<Int64Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name << "], the 'num' must be a Int, but got "
|
||||
<< num_value->ToString();
|
||||
}
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
||||
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
||||
|
||||
auto start_dtype = input_args[kInputIndex0]->BuildType();
|
||||
auto stop_dtype = input_args[kInputIndex1]->BuildType();
|
||||
|
@ -57,20 +50,56 @@ abstract::ShapePtr LinSpaceInferShape(const PrimitivePtr &primitive, const std::
|
|||
auto stop_shape_ptr = input_args[kInputIndex1]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(stop_shape_ptr);
|
||||
|
||||
auto num_value = input_args[kInputIndex2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(num_value);
|
||||
|
||||
bool is_compile = IsNoneOrAnyValue(num_value);
|
||||
// Do it later
|
||||
if (start_shape_ptr->IsDynamic() || stop_shape_ptr->IsDynamic()) {
|
||||
return input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
}
|
||||
|
||||
const auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(start_shape_ptr)[kShape];
|
||||
const auto stop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(stop_shape_ptr)[kShape];
|
||||
|
||||
// Checked in LinSpaceInferType, num is a Scalar
|
||||
const auto num_value = input_args[kInputIndex2]->BuildValue();
|
||||
const int64_t num = num_value->cast<Int64ImmPtr>()->value();
|
||||
int64_t num = 0;
|
||||
if (!is_compile) {
|
||||
if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
|
||||
if (num_value->isa<tensor::Tensor>()) {
|
||||
auto num_shape_ptr = input_args[kInputIndex2]->BuildShape();
|
||||
const auto num_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(num_shape_ptr)[kShape];
|
||||
if (num_shape.size() != 0) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the 'num' must be int or 0D int32/int64 Tensor, but got " << num_shape.size()
|
||||
<< "D Tensor.";
|
||||
}
|
||||
auto num_input = CheckAndConvertUtils::CheckTensorIntValue("num", num_value, prim_name);
|
||||
num = num_input[0];
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the 'num' must be int or 0D int32/int64 Tensor, but got "
|
||||
<< num_value->ToString() << ".";
|
||||
}
|
||||
} else if (input_args[kInputIndex2]->isa<abstract::AbstractScalar>()) {
|
||||
MS_EXCEPTION_IF_NULL(num_value);
|
||||
if (!num_value->isa<Int64Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the 'num' must be int or 0D int32/int64 Tensor, but got "
|
||||
<< num_value->ToString() << ".";
|
||||
}
|
||||
num = num_value->cast<Int64ImmPtr>()->value();
|
||||
} else {
|
||||
MS_EXCEPTION(TypeError) << "For primitive[" << prim_name
|
||||
<< "], the 'num' must be int or 0D int32/int64 Tensor, but got " << num_value->ToString()
|
||||
<< ".";
|
||||
}
|
||||
} else {
|
||||
ShapeVector out_shape = {abstract::Shape::SHP_ANY};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
(void)CheckAndConvertUtils::CheckValue<int64_t>("num", num, kGreaterThan, 0, prim_name);
|
||||
|
||||
const auto start_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(start_shape_ptr)[kShape];
|
||||
const auto stop_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(stop_shape_ptr)[kShape];
|
||||
|
||||
size_t batch_rank = 0;
|
||||
if (primitive->HasAttr(kBatchRank)) {
|
||||
auto value_ptr = primitive->GetAttr(kBatchRank);
|
||||
|
@ -101,6 +130,7 @@ AbstractBasePtr LinSpaceInfer(const abstract::AnalysisEnginePtr &, const Primiti
|
|||
auto infer_shape = LinSpaceInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_HOST_DEPENDS(kNameLinSpace, {2});
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(LinSpace, prim::kPrimLinSpace, LinSpaceInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -269,6 +269,7 @@ from .tril import _tril_aicpu
|
|||
from .bucketize import _bucketize_aicpu
|
||||
from .eye import _eye_aicpu
|
||||
from .logspace import _logspace_aicpu
|
||||
from .linspace import _lin_space_aicpu
|
||||
from .triu import _triu_aicpu
|
||||
from .dense_to_dense_set_operation import _dense_to_dense_set_operation_aicpu
|
||||
from .fractional_max_pool3d_with_fixed_ksize import _fractional_max_pool3d_with_fixed_ksize_aicpu
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LinSpace op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
lin_space_op_info = AiCPURegOp("LinSpace") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.input(0, "start", "required") \
|
||||
.input(1, "stop", "required") \
|
||||
.input(2, "num", "required") \
|
||||
.output(0, "output", "required") \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lin_space_op_info)
|
||||
def _lin_space_aicpu():
|
||||
"""LinSpace aicpu register"""
|
||||
return
|
|
@ -526,7 +526,6 @@ from .inplace_update import _inplace_update_tbe
|
|||
from .inplace_update_ds import _inplace_update_v2_ds_tbe
|
||||
from .split_v import _split_v_tbe
|
||||
from .in_top_k import _in_top_k_tbe
|
||||
from .lin_space import _lin_space_tbe
|
||||
from .diag import _diag_tbe
|
||||
from .diag_part import _diag_part_tbe
|
||||
from .matrix_diag import _matrix_diag_tbe
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""LinSpace op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
lin_space_op_info = TBERegOp("LinSpace") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("lin_space.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("lin_space_d") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("broadcast") \
|
||||
.input(0, "assist", False, "required", "all") \
|
||||
.input(1, "start", False, "required", "all") \
|
||||
.input(2, "stop", False, "required", "all") \
|
||||
.input(3, "num", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None, DataType.I32_None,
|
||||
DataType.F32_None,) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(lin_space_op_info)
|
||||
def _lin_space_tbe():
|
||||
"""LinSpace TBE register"""
|
||||
return
|
|
@ -2126,19 +2126,19 @@ def linspace(start, stop, num):
|
|||
\end{aligned}
|
||||
|
||||
Args:
|
||||
start (Tensor): Start value of interval. The tensor data type must be float32 and with shape of 0-D.
|
||||
stop (Tensor): Last value of interval. The tensor data type must be float32 and with shape of 0-D.
|
||||
num (int): Number of ticks in the interval, inclusive of start and stop.
|
||||
Must be positive int number.
|
||||
start (Tensor): Start value of interval. The tensor data type must be float32 or float64 and with shape of 0-D.
|
||||
stop (Tensor): Last value of interval. The tensor data type must be float32 or float64 and with shape of 0-D.
|
||||
num (Union[Tensor, int]): Number of ticks in the interval, inclusive of start and stop.
|
||||
Must be positive int number or 0D int32/int64 Tensor.
|
||||
|
||||
Returns:
|
||||
Tensor, has the same dtype as `start`, and the shape of :math:`(num)`
|
||||
|
||||
Raises:
|
||||
TypeError: If `start` or `stop` is not a Tensor.
|
||||
TypeError: If dtype of `start` or dtype of `stop` is not float32.
|
||||
TypeError: If dtype of `start` or dtype of `stop` is not float32 or float64.
|
||||
ValueError: If shape of `start` or shape of `stop` is not 0-D.
|
||||
TypeError: If `num` is not int.
|
||||
TypeError: If `num` is not int or 0D int32/int64 Tensor.
|
||||
ValueError: If `num` is not positive int number.
|
||||
|
||||
Supported Platforms:
|
||||
|
|
|
@ -5305,7 +5305,7 @@ class Eps(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class LinSpace(PrimitiveWithInfer):
|
||||
class LinSpace(Primitive):
|
||||
r"""
|
||||
Returns a Tensor whose value is `num` evenly spaced in the interval `start` and `stop` (including `start` and
|
||||
`stop`), and the length of the output Tensor is `num`.
|
||||
|
@ -5328,14 +5328,7 @@ class LinSpace(PrimitiveWithInfer):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize LinSpace"""
|
||||
|
||||
def check_elim(self, start, stop, num):
|
||||
if not isinstance(num, int):
|
||||
return False, None
|
||||
if num != 1:
|
||||
return False, None
|
||||
start_npy = start.asnumpy().reshape((1,))
|
||||
return True, Tensor(start_npy)
|
||||
self.init_prim_io_names(inputs=['start', 'stop', 'num'], outputs=['output'])
|
||||
|
||||
|
||||
class MatrixInverse(Primitive):
|
||||
|
|
|
@ -178,23 +178,3 @@ def test_lin_space_num():
|
|||
result_ms = ops.vmap(net, (0, 0))(start, stop).asnumpy()
|
||||
result_np = np.linspace(start_np, stop_np, num_np, axis=-1)
|
||||
assert np.allclose(result_ms, result_np)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_lin_space_num_1():
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for LinSpace Net
|
||||
Expectation: the result match to numpy
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
|
||||
start_np = 5
|
||||
stop_np = 150
|
||||
|
||||
start = Tensor(start_np, dtype=mstype.float32)
|
||||
stop = Tensor(stop_np, dtype=mstype.float32)
|
||||
num = Tensor(1)
|
||||
with pytest.raises(TypeError):
|
||||
ops.linspace(start, stop, num)
|
||||
|
|
Loading…
Reference in New Issue