!19789 Add Diag and DiagPart op for Ascend.

Merge pull request !19789 from liuxiao93/add-Diag-and-DiagPart
This commit is contained in:
i-robot 2021-07-13 12:21:54 +00:00 committed by Gitee
commit 49c8d6a27e
15 changed files with 536 additions and 5 deletions

View File

@ -56,6 +56,8 @@
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
#include "backend/optimizer/ascend/ir_fission/lin_space_fission.h"
#include "backend/optimizer/ascend/ir_fission/space_to_depth_split.h"
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
#include "backend/optimizer/ascend/ir_fission/diag_part_fission.h"
#include "backend/optimizer/ascend/ir_fission/max_pool3d_grad_grad_fission.h"
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/avgpool_3d_grad_fusion.h"
@ -179,6 +181,8 @@ void AddAscendIRFusionPass(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<TransposeReshapeFusion>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());
@ -324,6 +328,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
ir_fusion_pm->AddPass(std::make_shared<SpaceToDepthSplit>());
ir_fusion_pm->AddPass(std::make_shared<DiagFission>());
ir_fusion_pm->AddPass(std::make_shared<DiagPartFission>());
ir_fusion_pm->AddPass(std::make_shared<MaxPool3DGradGradFission>());
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DFusion>());
ir_fusion_pm->AddPass(std::make_shared<AvgPool3DGradFusion>());

View File

@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
#include <algorithm>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDiagInputNum = 1;
template <typename T>
void SetAssistTensorData(void *data, T value, size_t dims_size) {
MS_EXCEPTION_IF_NULL(data);
auto tensor_data = reinterpret_cast<T *>(data);
MS_EXCEPTION_IF_NULL(tensor_data);
for (size_t i = 0; i < dims_size; ++i) {
tensor_data[(1 + dims_size) * i] = value;
}
}
} // namespace
ValueNodePtr DiagFission::CreateAssistNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::vector<size_t> &ori_shape) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
std::vector<size_t> output_shape(ori_shape);
size_t dims = 1;
for (size_t i = 0; i < ori_shape.size(); i++) {
dims = dims * ori_shape[i];
}
output_shape.insert(output_shape.end(), ori_shape.begin(), ori_shape.end());
auto type = AnfAlgo::GetOutputInferDataType(node, 0);
std::vector<int64_t> assist_shape;
std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(assist_shape), SizeToLong);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, assist_shape);
AbstractBasePtr x_abstract;
if (type == kNumberTypeInt32) {
SetAssistTensorData<int32_t>(tensor->data_c(), 1, dims);
x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat16, assist_shape);
} else if (type == kNumberTypeFloat16) {
SetAssistTensorData<float16>(tensor->data_c(), float16(static_cast<float>(1)), dims);
x_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, assist_shape);
} else {
SetAssistTensorData<float>(tensor->data_c(), static_cast<float>(1), dims);
x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat, assist_shape);
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto assist_value_node = kernel_graph->NewValueNode(x_abstract, tensor);
kernel_graph->AddValueNodeToGraph(assist_value_node);
AnfAlgo::SetOutputInferTypeAndShape({type}, {output_shape}, assist_value_node.get());
return assist_value_node;
}
const BaseRef DiagFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto diag_prim = std::make_shared<Primitive>(prim::kPrimDiag->name());
return VectorRef({diag_prim, Xs});
}
const AnfNodePtr DiagFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto kernel_graph = graph->cast<KernelGraphPtr>();
auto diag_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(diag_cnode);
if (diag_cnode->size() != kDiagInputNum + 1) {
MS_LOG(INFO) << "The node " << diag_cnode->DebugString() << " is not equal to " << kDiagInputNum << " inputs";
return nullptr;
}
auto input_shape = AnfAlgo::GetOutputInferShape(diag_cnode->inputs()[kIndex1], 0);
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiag->name()))};
auto assist_const = CreateAssistNode(graph, diag_cnode, input_shape);
new_inputs.insert(new_inputs.end(), diag_cnode->inputs().begin() + 1, diag_cnode->inputs().end());
new_inputs.push_back(assist_const);
CNodePtr new_cnode = graph->NewCNode(new_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(diag_cnode->abstract());
new_cnode->set_scope(diag_cnode->scope());
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_const);
MS_LOG(INFO) << "Add assist tensor for diag op success.";
}
return new_cnode;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_
#include <vector>
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class DiagFission : public PatternProcessPass {
public:
explicit DiagFission(const std::string name = "Diag_fission", bool multigraph = true)
: PatternProcessPass(name, multigraph) {}
~DiagFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
protected:
ValueNodePtr CreateAssistNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const std::vector<size_t> &ori_shape) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_FISSION_H_

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ir_fission/diag_part_fission.h"
#include <memory>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore::opt {
const BaseRef DiagPartFission::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
auto diag_apart_prim = std::make_shared<Primitive>(prim::kPrimDiagPart->name());
return VectorRef({diag_apart_prim, Xs});
}
const AnfNodePtr DiagPartFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto diag_part_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(diag_part_cnode);
constexpr size_t kDiagPartInputNum = 1;
if (diag_part_cnode->size() != kDiagPartInputNum + 1) {
MS_LOG(INFO) << "The node " << diag_part_cnode->DebugString() << " is not equal to " << kDiagPartInputNum
<< " inputs";
return nullptr;
}
auto out_shape = AnfAlgo::GetOutputInferShape(node, 0);
std::vector<AnfNodePtr> new_node_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimDiagPart->name()))};
auto assist_node = CreateAssistNode(func_graph, diag_part_cnode, out_shape);
new_node_inputs.insert(new_node_inputs.end(), diag_part_cnode->inputs().begin() + 1, diag_part_cnode->inputs().end());
new_node_inputs.push_back(assist_node);
CNodePtr new_cnode = func_graph->NewCNode(new_node_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
new_cnode->set_abstract(diag_part_cnode->abstract());
new_cnode->set_scope(diag_part_cnode->scope());
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
if (kernel_graph != nullptr) {
kernel_graph->AddValueNodeToGraph(assist_node);
MS_LOG(INFO) << "Add assist tensor for DiagPart op success.";
}
return new_cnode;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,33 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ir_fission/diag_fission.h"
namespace mindspore {
namespace opt {
class DiagPartFission : public DiagFission {
public:
explicit DiagPartFission(bool multigraph = true) : DiagFission("diag_part_fission", multigraph) {}
~DiagPartFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FISSION_DIAG_PART_FISSION_H_

View File

@ -79,6 +79,8 @@ constexpr auto kFastGeLUGrad = "FastGeLUGrad";
constexpr auto kStridedSlice = "StridedSlice";
constexpr auto kZerosLike = "ZerosLike";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kDiag = "Diag";
constexpr auto kDiagPart = "DiagPart";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
@ -233,6 +235,8 @@ inline const PrimitivePtr kPrimResizeGrad = std::make_shared<Primitive>("ResizeG
inline const PrimitivePtr kPrimResizeNearestNeighbor = std::make_shared<Primitive>("ResizeNearestNeighbor");
inline const PrimitivePtr kPrimSort = std::make_shared<Primitive>("Sort");
inline const PrimitivePtr kPrimMaskedSelect = std::make_shared<Primitive>("MaskedSelect");
inline const PrimitivePtr kPrimDiag = std::make_shared<Primitive>(kDiag);
inline const PrimitivePtr kPrimDiagPart = std::make_shared<Primitive>(kDiagPart);
// NN
inline const PrimitivePtr kPrimAdam = std::make_shared<Primitive>("Adam");

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "ops/diag.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr DiagInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
CheckAndConvertUtils::CheckInteger("input rank", input_shape.size(), kGreaterEqual, 1, primitive->name());
std::vector<int64_t> out_shape(input_shape);
out_shape.insert(out_shape.end(), input_shape.begin(), input_shape.end());
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr PartInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x_dtype = input_args[0]->BuildType();
return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primitive->name());
}
} // namespace
AbstractBasePtr DiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return abstract::MakeAbstract(DiagInferShape(primitive, input_args), PartInferType(primitive, input_args));
}
} // namespace ops
} // namespace mindspore

39
mindspore/core/ops/diag.h Normal file
View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DIAG_H_
#define MINDSPORE_CORE_OPS_DIAG_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
class Diag : public PrimitiveC {
public:
Diag() : PrimitiveC(prim::kPrimDiag->name()) { InitIOName({"input_x"}, {"output"}); }
~Diag() = default;
MS_DECLARE_PARENT(Diag, PrimitiveC);
};
AbstractBasePtr DiagInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimDiagPtr = std::shared_ptr<Diag>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DIAG_H_

View File

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "ops/diag_part.h"
#include "ops/op_utils.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kScaleNum = 2;
abstract::ShapePtr DiagPartInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
if ((input_shape.size() % kScaleNum) != 0 || input_shape.size() == 0) {
MS_EXCEPTION(ValueError) << "For DiagPart, input rank must be non-zero and even, but got rank "
<< input_shape.size();
}
auto length = input_shape.size() / kScaleNum;
std::vector<int64_t> out_shape;
for (size_t i = 0; i < length; i++) {
CheckAndConvertUtils::Check("input_shape[i + rank(input_shape) / 2]", input_shape[i + length], kEqual,
"input_shape[i]", input_shape[i], op_name, ValueError);
out_shape.emplace_back(input_shape[i]);
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr DiagPartInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto x_dtype = input_args[0]->BuildType();
return CheckAndConvertUtils::CheckTensorTypeValid("input type", x_dtype, common_valid_types, primitive->name());
}
} // namespace
AbstractBasePtr DiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
return abstract::MakeAbstract(DiagPartInferShape(primitive, input_args), DiagPartInferType(primitive, input_args));
}
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DIAG_PART_H_
#define MINDSPORE_CORE_OPS_DIAG_PART_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
class DiagPart : public PrimitiveC {
public:
DiagPart() : PrimitiveC(prim::kPrimDiagPart->name()) { InitIOName({"input_x"}, {"output"}); }
~DiagPart() = default;
MS_DECLARE_PARENT(DiagPart, PrimitiveC);
};
AbstractBasePtr DiagPartInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimDiagPartPtr = std::shared_ptr<DiagPart>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DIAG_PART_H_

View File

@ -414,8 +414,9 @@ abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &
return shape;
}
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
int64_t value, const string &prim_name, ExceptionType exception_type) {
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type,
const string &value_name, int64_t value, const string &prim_name,
ExceptionType exception_type) {
auto iter = kCompareMap<float>.find(compare_type);
if (iter == kCompareMap<float>.end()) {
MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map";
@ -433,8 +434,8 @@ void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, Comp
if (iter_to_string == kCompareToString.end()) {
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map";
}
MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value
<< " but got " << arg_value;
MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value_name
<< ": " << value << ", but got " << arg_value;
}
TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,

View File

@ -337,6 +337,8 @@ from .inplace_update import _inplace_update_tbe
from .splitv import _split_v_tbe
from .in_top_k import _in_top_k_tbe
from .lin_space import _lin_space_tbe
from .diag import _diag_tbe
from .diag_part import _diag_part_tbe
from .matrix_diag import _matrix_diag_tbe
from .matrix_diag_part import _matrix_diag_part_tbe
from .matrix_set_diag import _matrix_set_diag_tbe

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DiagD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
diag_d_op_info = TBERegOp("Diag") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("diag_d.so") \
.compute_cost(10) \
.kernel_name("diag_d") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "assist", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(diag_d_op_info)
def _diag_tbe():
"""DiagD TBE register"""
return

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DiagPartD op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
diag_part_d_op_info = TBERegOp("DiagPart") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("diag_part_d.so") \
.compute_cost(10) \
.kernel_name("diag_part_d") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "assist", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(diag_part_d_op_info)
def _diag_part_tbe():
"""DiagPartD TBE register"""
return

View File

@ -3285,6 +3285,13 @@ class Diag(PrimitiveWithInfer):
Outputs:
Tensor, has the same dtype as the `input_x`.
Raises:
TypeError: If `input_x` is not a Tensor.
ValueError: If rank of `input_x` is less than 1.
Supported Platforms:
``Ascend``
Examples:
>>> input_x = Tensor([1, 2, 3, 4])
>>> diag = ops.Diag()
@ -3330,11 +3337,19 @@ class DiagPart(PrimitiveWithInfer):
:math:`output[i_1,..., i_k] = input[i_1,..., i_k, i_1,..., i_k]`.
Inputs:
- **input_x** (Tensor) - tensor of rank k where k is even and not zero.
- **input_x** (Tensor) - The input tensor of rank 2k, k is not zero.
Outputs:
Tensor, the extracted diagonal has the same dtype as the `input_x`.
Raises:
TypeError: If `input_x` is not a Tensor.
ValueError: If rank of `input_x` is not even or zero.
ValueError: If input_shape[i] is not equal to input_shape[i + len(input_shape)/2].
Supported Platforms:
``Ascend``
Examples
>>> input_x = Tensor([[1, 0, 0, 0],
... [0, 2, 0, 0],