forked from mindspore-Ecosystem/mindspore
!19789 Add Diag and DiagPart op for Ascend.
Merge pull request !19789 from liuxiao93/add-Diag-and-DiagPart
This commit is contained in:
commit
49c8d6a27e
|
@ -56,6 +56,8 @@
|
|||
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/space_to_depth_split.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/diag_part_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
|
||||
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.h"
|
||||
|
@ -179,6 +181,8 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());
|
||||
|
@ -324,6 +328,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
|||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kDiagInputNum = 1;
|
||||
|
||||
template <typename T>
|
||||
void SetAssistTensorData(void *data, T value, size_t dims_size) {
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto tensor_data = reinterpret_cast<T *>(data);
|
||||
MS_EXCEPTION_IF_NULL(tensor_data);
|
||||
for (size_t i = 0; i < dims_size; ++i) {
|
||||
tensor_data[(1 + dims_size) * i] = value;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ValueNodePtr DiagFission::CreateAssistNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const std::vector<size_t> &ori_shape) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
std::vector<size_t> output_shape(ori_shape);
|
||||
size_t dims = 1;
|
||||
for (size_t i = 0; i < ori_shape.size(); i++) {
|
||||
dims = dims * ori_shape[i];
|
||||
}
|
||||
output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end());
|
||||
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
|
||||
std::vector<int64_t> assist_shape;
|
||||
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(assist_shape), SizeToLong);
|
||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, assist_shape);
|
||||
AbstractBasePtr x_abstract;
|
||||
if (type == kNumberTypeInt32) {
|
||||
SetAssistTensorData<int32_t>(tensor->data_c(), 1, dims);
|
||||
x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
|
||||
} else if (type == kNumberTypeFloat16) {
|
||||
SetAssistTensorData<float16>(tensor->data_c(), float16(static_cast<float>(1)), dims);
|
||||
x_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, assist_shape);
|
||||
} else {
|
||||
SetAssistTensorData<float>(tensor->data_c(), static_cast<float>(1), dims);
|
||||
x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat, assist_shape);
|
||||
}
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto assist_value_node = kernel_graph->NewValueNode(x_abstract, tensor);
|
||||
kernel_graph->AddValueNodeToGraph(assist_value_node);
|
||||
AnfAlgo::SetOutputInferTypeAndShape({type}, {output_shape}, assist_value_node.get());
|
||||
return assist_value_node;
|
||||
}
|
||||
|
||||
const BaseRef DiagFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto diag_prim = std::make_shared<Primitive>(prim::kPrimDiag->name());
|
||||
return VectorRef({diag_prim, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_graph = graph->cast<KernelGraphPtr>();
|
||||
auto diag_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(diag_cnode);
|
||||
if (diag_cnode->size() != kDiagInputNum + 1) {
|
||||
MS_LOG(INFO) << "The node " << diag_cnode->DebugString() << " is not equal to " << kDiagInputNum << " inputs";
|
||||
return nullptr;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetOutputInferShape(diag_cnode->inputs()[kIndex1], 0);
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiag->name()))};
|
||||
auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape);
|
||||
new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
|
||||
new_inputs.push_back(assist_const);
|
||||
CNodePtr new_cnode = graph->NewCNode(new_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(diag_cnode->abstract());
|
||||
new_cnode->set_scope(diag_cnode->scope());
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->AddValueNodeToGraph(assist_const);
|
||||
MS_LOG(INFO) << "Add assist tensor for diag op success.";
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DiagFission : public PatternProcessPass {
|
||||
public:
|
||||
explicit DiagFission(const std::string name = "Diag_fission", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph) {}
|
||||
~DiagFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
ValueNodePtr CreateAssistNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const std::vector<size_t> &ori_shape) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ir_fission/diag_part_fission.h"
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
const BaseRef DiagPartFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto diag_apart_prim = std::make_shared<Primitive>(prim::kPrimDiagPart->name());
|
||||
return VectorRef({diag_apart_prim, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto diag_part_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(diag_part_cnode);
|
||||
constexpr size_t kDiagPartInputNum = 1;
|
||||
if (diag_part_cnode->size() != kDiagPartInputNum + 1) {
|
||||
MS_LOG(INFO) << "The node " << diag_part_cnode->DebugString() << " is not equal to " << kDiagPartInputNum
|
||||
<< " inputs";
|
||||
return nullptr;
|
||||
}
|
||||
auto out_shape = AnfAlgo::GetOutputInferShape(node, 0);
|
||||
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiagPart->name()))};
|
||||
auto assist_node = CreateAssistNode(func_graph, diag_part_cnode, out_shape);
|
||||
new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1, diag_part_cnode->inputs().end());
|
||||
new_node_inputs.push_back(assist_node);
|
||||
CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(diag_part_cnode->abstract());
|
||||
new_cnode->set_scope(diag_part_cnode->scope());
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->AddValueNodeToGraph(assist_node);
|
||||
MS_LOG(INFO) << "Add assist tensor for DiagPart op success.";
|
||||
}
|
||||
return new_cnode;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DiagPartFission : public DiagFission {
|
||||
public:
|
||||
explicit DiagPartFission(bool multigraph = true) : DiagFission("diag_part_fission", multigraph) {}
|
||||
~DiagPartFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_
|
|
@ -79,6 +79,8 @@ constexpr auto kFastGeLUGrad = "FastGeLUGrad";
|
|||
constexpr auto kStridedSlice = "StridedSlice";
|
||||
constexpr auto kZerosLike = "ZerosLike";
|
||||
constexpr auto kOnesLike = "OnesLike";
|
||||
constexpr auto kDiag = "Diag";
|
||||
constexpr auto kDiagPart = "DiagPart";
|
||||
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
|
||||
constexpr auto kTranspose = "Transpose";
|
||||
|
||||
|
@ -233,6 +235,8 @@ inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeG
|
|||
inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared<Primitive>("ResizeNearestNeighbor");
|
||||
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
|
||||
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");
|
||||
inline const PrimitivePtr kPrimDiag = std::make_shared<Primitive>(kDiag);
|
||||
inline const PrimitivePtr kPrimDiagPart = std::make_shared<Primitive>(kDiagPart);
|
||||
|
||||
// NN
|
||||
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/diag.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr DiagInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("input rank", input_shape.size(), kGreaterEqual, 1, primitive->name());
|
||||
std::vector<int64_t> out_shape(input_shape);
|
||||
out_shape.insert(out_shape.end(), input_shape.begin(), input_shape.end());
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr PartInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primitive->name());
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr DiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return abstract::MakeAbstract(DiagInferShape(primitive, input_args), PartInferType(primitive, input_args));
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_DIAG_H_
|
||||
#define MINDSPORE_CORE_OPS_DIAG_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class Diag : public PrimitiveC {
|
||||
public:
|
||||
Diag() : PrimitiveC(prim::kPrimDiag->name()) { InitIOName({"input_x"}, {"output"}); }
|
||||
~Diag() = default;
|
||||
MS_DECLARE_PARENT(Diag, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr DiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimDiagPtr = std::shared_ptr<Diag>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_DIAG_H_
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/diag_part.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr size_t kScaleNum = 2;
|
||||
|
||||
abstract::ShapePtr DiagPartInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
if ((input_shape.size() % kScaleNum) != 0 || input_shape.size() == 0) {
|
||||
MS_EXCEPTION(ValueError) << "For DiagPart, input rank must be non-zero and even, but got rank "
|
||||
<< input_shape.size();
|
||||
}
|
||||
auto length = input_shape.size() / kScaleNum;
|
||||
std::vector<int64_t> out_shape;
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
CheckAndConvertUtils::Check("input_shape[i + rank(input_shape) / 2]", input_shape[i + length], kEqual,
|
||||
"input_shape[i]", input_shape[i], op_name, ValueError);
|
||||
out_shape.emplace_back(input_shape[i]);
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr DiagPartInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x_dtype = input_args[0]->BuildType();
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primitive->name());
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr DiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
return abstract::MakeAbstract(DiagPartInferShape(primitive, input_args), DiagPartInferType(primitive, input_args));
|
||||
}
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_DIAG_PART_H_
|
||||
#define MINDSPORE_CORE_OPS_DIAG_PART_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class DiagPart : public PrimitiveC {
|
||||
public:
|
||||
DiagPart() : PrimitiveC(prim::kPrimDiagPart->name()) { InitIOName({"input_x"}, {"output"}); }
|
||||
~DiagPart() = default;
|
||||
MS_DECLARE_PARENT(DiagPart, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr DiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimDiagPartPtr = std::shared_ptr<DiagPart>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_DIAG_PART_H_
|
|
@ -414,8 +414,9 @@ abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &
|
|||
return shape;
|
||||
}
|
||||
|
||||
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
|
||||
int64_t value, const string &prim_name, ExceptionType exception_type) {
|
||||
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
||||
const string &value_name, int64_t value, const string &prim_name,
|
||||
ExceptionType exception_type) {
|
||||
auto iter = kCompareMap<float>.find(compare_type);
|
||||
if (iter == kCompareMap<float>.end()) {
|
||||
MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
|
||||
|
@ -433,8 +434,8 @@ void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, Comp
|
|||
if (iter_to_string == kCompareToString.end()) {
|
||||
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
|
||||
}
|
||||
MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value
|
||||
<< " but got " << arg_value;
|
||||
MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value_name
|
||||
<< ": " << value << ", but got " << arg_value;
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
|
||||
|
|
|
@ -337,6 +337,8 @@ from .inplace_update import _inplace_update_tbe
|
|||
from .splitv import _split_v_tbe
|
||||
from .in_top_k import _in_top_k_tbe
|
||||
from .lin_space import _lin_space_tbe
|
||||
from .diag import _diag_tbe
|
||||
from .diag_part import _diag_part_tbe
|
||||
from .matrix_diag import _matrix_diag_tbe
|
||||
from .matrix_diag_part import _matrix_diag_part_tbe
|
||||
from .matrix_set_diag import _matrix_set_diag_tbe
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""DiagD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
diag_d_op_info = TBERegOp("Diag") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("diag_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("diag_d") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "assist", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(diag_d_op_info)
|
||||
def _diag_tbe():
|
||||
"""DiagD TBE register"""
|
||||
return
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""DiagPartD op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
diag_part_d_op_info = TBERegOp("DiagPart") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("diag_part_d.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("diag_part_d") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "assist", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(diag_part_d_op_info)
|
||||
def _diag_part_tbe():
|
||||
"""DiagPartD TBE register"""
|
||||
return
|
|
@ -3285,6 +3285,13 @@ class Diag(PrimitiveWithInfer):
|
|||
Outputs:
|
||||
Tensor, has the same dtype as the `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
ValueError: If rank of `input_x` is less than 1.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor([1, 2, 3, 4])
|
||||
>>> diag = ops.Diag()
|
||||
|
@ -3330,11 +3337,19 @@ class DiagPart(PrimitiveWithInfer):
|
|||
:math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - tensor of rank k where k is even and not zero.
|
||||
- **input_x** (Tensor) - The input tensor of rank 2k, k is not zero.
|
||||
|
||||
Outputs:
|
||||
Tensor, the extracted diagonal has the same dtype as the `input_x`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `input_x` is not a Tensor.
|
||||
ValueError: If rank of `input_x` is not even or zero.
|
||||
ValueError: If input_shape[i] is not equal to input_shape[i + len(input_shape)/2].
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
|
||||
Examples
|
||||
>>> input_x = Tensor([[1, 0, 0, 0],
|
||||
... [0, 2, 0, 0],
|
||||
|
|
Loading…
Reference in New Issue