!45367 Graph Kernel adapt convert input attr about graph kernel

Merge pull request !45367 from ZengZitao/convert_input_attr_gk
This commit is contained in:
i-robot 2022-11-17 02:19:12 +00:00 committed by Gitee
commit 01e44b24a8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
66 changed files with 1030 additions and 1013 deletions

View File

@ -1,7 +1,7 @@
mindspore.ops.ReduceSum
=========================
.. py:class:: mindspore.ops.ReduceSum(keep_dims=False)
.. py:class:: mindspore.ops.ReduceSum(keep_dims=False, skip_mode=False)
默认情况下输出Tensor各维度上的和以达到对所有维度进行归约的目的。也可以对指定维度进行求和归约。
@ -9,15 +9,18 @@ mindspore.ops.ReduceSum
参数:
- **keep_dims** (bool) - 如果为True则保留计算维度长度为1。如果为False则不保留计算维度。默认值False输出结果会降低维度。
- **skip_mode** (bool) - 如果为True并且axis为空tuple或空list不进行ReduceSum计算axis为其他值正常运算。如果为False则正常进行运算。默认值False。
输入:
- **x** (Tensor[Number]) - ReduceSum的输入其数据类型为Number。shape :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。秩应小于8。
- **axis** (Union[int, tuple(int), list(int)]) - 要减少的维度。默认值: (),缩小所有维度。只允许常量值,取值范围[-rank(`x`), rank(`x`))。
- **axis** (Union[int, tuple(int), list(int)]) - 要减少的维度。默认值: ()当skip_mode为False时缩小所有维度。只允许常量值,取值范围[-rank(`x`), rank(`x`))。
输出:
Tensor具有与输入 `x` 相同的shape。
- 如果轴为()且keep_dims为False则输出一个0维Tensor表示输入Tensor中所有元素的和。
- 如果轴为()且keep_dims为Falseskip_mode为False则输出一个0维Tensor表示输入Tensor中所有元素的和。
- 如果轴为()且skip_mode为True则不进行ReduceSum运算输出Tensor等于输入Tensor。
- 如果轴为int取值为2并且keep_dims为False则输出的shape为 :math:`(x_1, x_3, ..., x_R)`
@ -25,5 +28,6 @@ mindspore.ops.ReduceSum
异常:
- **TypeError** - `keep_dims` 不是bool。
- **TypeError** - `skip_mode` 不是bool。
- **TypeError** - `x` 不是Tensor。
- **ValueError** - `axis` 取值为None。

View File

@ -307,12 +307,7 @@ void InferOp(const CNodePtr &cnode, void *args) {
MS_EXCEPTION_IF_NULL(kernel_mod);
kernel::KernelArgs kernel_args;
if (AnfAlgo::IsDynamicShapeSkipExecute(cnode)) {
std::vector<TypeId> dtypes{common::AnfAlgo::GetOutputInferDataType(cnode, 0)};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetInputDeviceShape(cnode, 0)}, cnode.get());
} else {
InferShape(cnode, &kernel_args.depend_tensor_map, args);
}
if (auto kernel_mod_type = kernel_mod->GetKernelModType(); IsCpuGpuKernelMod(kernel_mod_type)) {
auto update = kernel::AbstractArgsFromCNode(cnode, IsDeprecatedCpuOrGpuKernelMod(kernel_mod_type));

View File

@ -257,12 +257,21 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt
return tensor;
}
// Build the tensor that represents an empty value tuple: an int64 tensor of shape {0}
// carrying default-format device info.
tensor::TensorPtr CreateEmptyTupleTensor() {
  const std::vector<int64_t> empty_shape{0};
  auto empty_tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), empty_shape);
  MS_EXCEPTION_IF_NULL(empty_tensor);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
  empty_tensor->set_device_info(device_info);
  return empty_tensor;
}
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
MS_EXCEPTION_IF_NULL(value_tuple);
tensor::TensorPtr tensor = nullptr;
if (value_tuple->value().empty()) {
MS_LOG(WARNING) << "The value tuple is empty.";
return nullptr;
tensor = CreateEmptyTupleTensor();
return tensor;
}
ValuePtr v = *(value_tuple->value().begin());
MS_EXCEPTION_IF_NULL(v);

View File

@ -1355,12 +1355,9 @@ void AnfRuntimeAlgorithm::UpdateGraphValidRefPair(const KernelGraphPtr &graph) {
graph->set_ref_out_in_map(new_ref_map);
}
bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const std::string &op_name, const ShapeVector &axes_shape) {
bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
// Skip run ReduceSum when axis is a Empty Tensor
if (op_name != kReduceSumOpName) {
return false;
}
if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
return true;
}
return false;
@ -1374,6 +1371,11 @@ bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
return false;
}
bool skip_mode = false;
if (common::AnfAlgo::HasNodeAttr(kAttrSkipMode, cnode)) {
skip_mode = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrSkipMode);
}
const size_t axes_index = 1;
if (cnode->inputs().size() <= axes_index + 1) {
return false;
@ -1387,7 +1389,7 @@ bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(axes_abs);
auto axes_shape = AnfAlgo::GetInputDeviceShape(cnode, axes_index);
if (axes_abs->isa<abstract::AbstractTensor>()) {
if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
return true;
}
}

View File

@ -172,7 +172,7 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);
static void UpdateGraphValidRefPair(const KernelGraphPtr &graph);
static bool IsDynamicShapeSkipExecute(const std::string &op_name, const ShapeVector &axes_shape);
static bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape);
static bool IsDynamicShapeSkipExecute(const CNodePtr &cnode);
// return true if need to update output's shape and type after launch
static bool IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr &cnode);

View File

@ -30,6 +30,7 @@
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/adapter/callback_impl.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "backend/common/pass/inplace_assign_for_custom_op.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"
@ -65,7 +66,7 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
{prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}},
{prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}},
{prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}},
{prim::kPrimVmapUnstackAssign->name(), {AttrToInputDeco::GetCreator(true)}},
{prim::kPrimVmapUnstackAssign->name(), {AttrToInputDeco::Creator}},
};
const auto iter = creators.find(GetCNodePrimitive(node)->name());
if (iter != creators.end()) {
@ -178,7 +179,7 @@ AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const
return nullptr;
}
auto expander = GetExpander(node);
expander = AttrToInputDeco::GetCreator(true)(expander);
expander = AttrToInputDeco::Creator(expander);
auto res = expander->Run(node);
auto expand_fg = GetCNodeFuncGraph(res);
if (expand_fg != nullptr) {
@ -240,49 +241,6 @@ AnfNodePtr SetDynamicShapeAttrDeco::Run(const AnfNodePtr &node) {
return new_cnode;
}
// Walk every cnode of `graph` and turn selected primitive attributes back into real inputs.
// Which attributes are converted depends on is_backend_: the backend only rewrites
// TupleGetItem, while the frontend also rewrites Cast/Reshape/Reduce*/Transpose.
// NOTE(review): this member was removed in this change in favor of ConvertOpUtils::ConvertAttrToInput.
void AttrToInputDeco::ConvertAttrToInput(const FuncGraphPtr &graph) {
  auto todos = TopoSort(graph->get_return());
  for (const auto &node : todos) {
    auto primitive = GetCNodePrimitive(node);
    if (primitive == nullptr) {
      // Not a cnode with a primitive (e.g. parameter/value node) — nothing to convert.
      continue;
    }
    // Map of op name -> positions (in kAttrInputNames order) that live as attributes.
    std::map<std::string, std::set<size_t>> attr2input_map;
    if (is_backend_) {
      attr2input_map = std::map<std::string, std::set<size_t>>{{prim::kPrimTupleGetItem->name(), {1}}};
    } else {
      attr2input_map = std::map<std::string, std::set<size_t>>{
        {prim::kPrimCast->name(), {1}}, {prim::kPrimReshape->name(), {1}}, {prim::kPrimReduceMax->name(), {1}},
        {prim::kPrimReduceMin->name(), {1}}, {prim::kPrimReduceSum->name(), {1}}, {prim::kPrimTranspose->name(), {1}},
        {prim::kPrimTupleGetItem->name(), {1}}};
    }
    if (attr2input_map.count(primitive->name()) != 0) {
      auto input_names = primitive->GetAttr(kAttrInputNames);
      auto cnode = dyn_cast<CNode>(node);
      AnfNodePtrList inputs = cnode->inputs();
      AnfNodePtrList new_inputs{inputs[0]};
      auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
      auto attrs_map = attr2input_map[primitive->name()];
      // j tracks the next unconsumed real input; positions named in attrs_map are
      // materialized from attributes instead, so the two streams are interleaved here.
      size_t j = 1;
      for (size_t i = 0; i < input_names_vec.size(); ++i) {
        if (attrs_map.count(i) != 0) {
          auto value = primitive->GetAttr(input_names_vec[i]);
          auto value_node = std::make_shared<ValueNode>(value);
          value_node->set_abstract(value->ToAbstract());
          new_inputs.push_back(value_node);
        } else {
          if (j >= inputs.size()) {
            MS_LOG(EXCEPTION) << "Index " << j << " is larger than input size [" << inputs.size() << "]";
          }
          new_inputs.push_back(inputs[j]);
          j++;
        }
      }
      cnode->set_inputs(new_inputs);
    }
  }
}
AnfNodePtr AttrToInputDeco::Run(const AnfNodePtr &node) {
auto new_node = decorated_->Run(node);
if (new_node == nullptr) {
@ -290,7 +248,10 @@ AnfNodePtr AttrToInputDeco::Run(const AnfNodePtr &node) {
}
auto new_cnode = dyn_cast<CNode>(new_node);
auto expand_fg = GetCNodeFuncGraph(new_cnode);
ConvertAttrToInput(expand_fg);
auto todos = TopoSort(expand_fg->get_return());
for (const auto &node : todos) {
ConvertOpUtils::ConvertAttrToInput(node);
}
new_cnode->set_input(0, NewValueNode(expand_fg));
return new_cnode;
}

View File

@ -75,19 +75,12 @@ class SetDynamicShapeAttrDeco : public ExpanderDecorator {
class COMMON_EXPORT AttrToInputDeco : public ExpanderDecorator {
public:
AttrToInputDeco(const ExpanderPtr &decorated, bool is_backend)
: ExpanderDecorator(decorated), is_backend_(is_backend) {}
explicit AttrToInputDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
~AttrToInputDeco() override = default;
static ExpanderCreatorFunc GetCreator(bool is_backend) {
return [is_backend](const ExpanderPtr &decorated) {
return std::static_pointer_cast<Expander>(std::make_shared<AttrToInputDeco>(decorated, is_backend));
};
static ExpanderPtr Creator(const ExpanderPtr &decorated) {
return std::static_pointer_cast<Expander>(std::make_shared<AttrToInputDeco>(decorated));
}
AnfNodePtr Run(const AnfNodePtr &node) override;
protected:
void ConvertAttrToInput(const FuncGraphPtr &graph);
bool is_backend_{false};
};
/**

View File

@ -0,0 +1,142 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/ccsrc/common/graph_kernel/core/convert_op_input_attr.h"
#include "mindspore/ccsrc/common/graph_kernel/core/graph_kernel_callback.h"
#include "mindspore/core/ops/core_ops.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"
namespace mindspore::graphkernel {
std::map<std::string, HashSet<size_t>> ConvertOpUtils::op_idx_info_ = {{prim::kPrimReshape->name(), {1}},
{prim::kPrimReduceMax->name(), {1}},
{prim::kPrimExpandDims->name(), {1}},
{prim::kPrimReduceMin->name(), {1}},
{prim::kPrimReduceSum->name(), {1}},
{prim::kPrimTranspose->name(), {1}},
{prim::kPrimTile->name(), {1}},
{prim::kPrimReduceMean->name(), {1}},
{prim::kPrimSlice->name(), {1, 2}},
{prim::kPrimStridedSlice->name(), {1, 2, 3}},
{prim::kPrimOneHot->name(), {1}},
{prim::kPrimReduceFusion->name(), {1}},
{prim::kPrimConstantOfShape->name(), {0}},
{prim::kPrimGather->name(), {2}},
{prim::kPrimTupleGetItem->name(), {1}},
{prim::kPrimUnsortedSegmentSum->name(), {2}},
{prim::kPrimCumSum->name(), {1}}};
// Convert the const-value inputs of `cnode` at positions `input_idx` into primitive
// attributes (named after kAttrInputNames) and drop them from the input list.
// Returns false when conversion is impossible: missing input_names attr, or an index
// beyond the declared input names.
bool ConvertOpUtils::ConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx) {
  AnfNodePtrList new_inputs;
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(INFO) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  // inputs[0] is the primitive value node; keep it and filter the real operands.
  new_inputs.push_back(inputs[0]);
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_idx.count(i) != 0) {
      if (i >= input_names_vec.size()) {
        MS_LOG(INFO) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
        return false;
      }
      // Only value nodes and parameters can carry a compile-time constant.
      ValuePtr value = nullptr;
      if (input_node->isa<ValueNode>()) {
        auto value_node = input_node->cast<ValueNodePtr>();
        value = value_node->value();
      } else if (input_node->isa<Parameter>()) {
        auto parameter_node = input_node->cast<ParameterPtr>();
        value = parameter_node->abstract()->BuildValue();
      }
      if (value != nullptr) {
        auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(input_names_vec[i], value, primitive->name());
        auto tensor = value->cast<tensor::TensorPtr>();
        // BUGFIX: the cast yields null when `value` is not a Tensor (e.g. a scalar Imm);
        // guard before dereferencing instead of crashing on shape_c().
        MS_EXCEPTION_IF_NULL(tensor);
        auto tensor_shape = tensor->shape_c();
        // A 0-D tensor holds a single scalar — store it as a scalar attr; otherwise keep the vector.
        if (tensor_shape.empty()) {
          primitive->set_attr(input_names_vec[i], MakeValue(value_vector[0]));
        } else {
          primitive->set_attr(input_names_vec[i], MakeValue(value_vector));
        }
      } else {
        MS_LOG(ERROR) << input_names_vec[i] << "'s Value is null!";
      }
    } else {
      new_inputs.push_back(inputs[i + 1]);
    }
  }
  if (new_inputs.size() != inputs.size()) {
    cnode->set_inputs(new_inputs);
  }
  return true;
}
void ConvertOpUtils::ConvertAttrToInput(const AnfNodePtr &node) {
if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
return;
}
auto cnode = dyn_cast<CNode>(node);
auto primitive = GetCNodePrimitive(cnode);
if (primitive == nullptr) {
return;
}
auto attr2input_map = ConvertOpUtils::GetOpIndexInfo();
if (attr2input_map.count(primitive->name()) != 0) {
auto input_names = primitive->GetAttr(kAttrInputNames);
AnfNodePtrList inputs = cnode->inputs();
AnfNodePtrList new_inputs{inputs[0]};
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
auto attrs_map = attr2input_map[primitive->name()];
size_t j = 1;
for (size_t i = 0; i < input_names_vec.size(); ++i) {
if (attrs_map.count(i) != 0) {
auto value = primitive->GetAttr(input_names_vec[i]);
ValueNodePtr value_node;
// Adaptive for TupleGetItem, this op can not make attr to a tensor input.
if (primitive->name() == prim::kPrimTupleGetItem->name()) {
value_node = std::make_shared<ValueNode>(value);
value_node->set_abstract(value->ToAbstract());
} else if (value->isa<Scalar>()) {
auto value_scalar = GetValue<int64_t>(value);
auto tensor_ptr = std::make_shared<tensor::Tensor>(value_scalar, kInt64);
value_node = std::make_shared<ValueNode>(tensor_ptr);
value_node->set_abstract(tensor_ptr->ToAbstract());
} else if (value->isa<ValueTuple>()) {
auto value_tuple = GetValue<std::vector<int64_t>>(value);
auto tensor_ptr = std::make_shared<tensor::Tensor>(value_tuple, kInt64);
value_node = std::make_shared<ValueNode>(tensor_ptr);
value_node->set_abstract(tensor_ptr->ToAbstract());
} else {
MS_LOG(EXCEPTION) << "The value should be a scalar or value tuple";
}
Callback::Instance()->SetEmptyKernelInfo(value_node);
new_inputs.push_back(value_node);
} else {
if (j >= inputs.size()) {
MS_LOG(EXCEPTION) << "Index " << j << " is larger than input size [" << inputs.size() << "]";
}
new_inputs.push_back(inputs[j]);
j++;
}
}
cnode->set_inputs(new_inputs);
}
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_
#include <string>
#include <tuple>
#include <vector>
#include <map>
#include <memory>
#include "mindspore/core/ops/core_ops.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/hash_map.h"
namespace mindspore::graphkernel {
// Utility collection for moving operands between const-tensor-input form and
// primitive-attribute form for graph kernel ops. The op -> input-index registry
// is shared by the cluster, expander and splitter passes.
class ConvertOpUtils {
 public:
  // Turns the const inputs of `cnode` listed in `input_idx` into attributes;
  // returns false when conversion is not possible (see the .cc for the cases).
  static bool ConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx);
  // Restores attribute-form operands of `node` back to tensor inputs.
  // NOTE(review): skipped on Ascend inside the implementation — verify against the .cc.
  static void ConvertAttrToInput(const AnfNodePtr &node);
  // Registry of op name -> indices (in kAttrInputNames order) taking part in conversion.
  static std::map<std::string, HashSet<size_t>> &GetOpIndexInfo() { return op_idx_info_; }
  static std::map<std::string, HashSet<size_t>> op_idx_info_;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_

View File

@ -20,6 +20,7 @@
#include "utils/anf_utils.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/expanders/op_desc_registry.h"
namespace mindspore::graphkernel {
@ -47,42 +48,15 @@ CNodePtr ExpanderDecorator::QuickCloneCNode(const AnfNodePtr &node, bool clone_p
return new_node;
}
// Convert the const ValueNode inputs of `cnode` at positions input_idx_ into primitive
// attributes named after kAttrInputNames, removing them from the input list.
// NOTE(review): this member was removed in this change in favor of ConvertOpUtils::ConstInputToAttr.
bool InputToAttrDeco::ConstInputToAttr(const CNodePtr &cnode) const {
  AnfNodePtrList new_inputs;
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(INFO) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  // inputs[0] is the primitive value node; keep it and filter the real operands.
  new_inputs.push_back(inputs[0]);
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_idx_.count(i) != 0 && input_node->isa<ValueNode>()) {
      auto value_node = input_node->cast<ValueNodePtr>();
      if (i >= input_names_vec.size()) {
        MS_LOG(INFO) << "Index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
        return false;
      }
      // Store the constant directly as an attribute; the input itself is dropped.
      auto value = value_node->value();
      primitive->set_attr(input_names_vec[i], value);
    } else {
      new_inputs.push_back(inputs[i + 1]);
    }
  }
  if (new_inputs.size() != inputs.size()) {
    cnode->set_inputs(new_inputs);
  }
  return true;
}
// Clone `node` and convert its registered const inputs to attributes before delegating
// to the decorated expander. Returns nullptr when the op is not in the registry or the
// conversion fails, leaving the original node untouched.
AnfNodePtr InputToAttrDeco::Run(const AnfNodePtr &node) {
  // Work on a clone so a failed conversion cannot corrupt the original graph node.
  auto cnode = QuickCloneCNode(node, true);
  // PERF: take the registry by const reference — GetOpIndexInfo() returns a reference,
  // and the previous `auto` copy duplicated the whole map on every call.
  const auto &all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
  auto iter = all_op_index_info.find(GetCNodePrimitive(node)->name());
  if (iter == all_op_index_info.end()) {
    MS_LOG(DEBUG) << "This Node don't need convert input to attr";
    return nullptr;
  }
  auto ret = ConvertOpUtils::ConstInputToAttr(cnode, iter->second);
  return ret ? decorated_->Run(cnode) : nullptr;
}

View File

@ -65,20 +65,12 @@ using ExpanderCreatorFuncList = std::vector<ExpanderCreatorFunc>;
class COMMON_EXPORT InputToAttrDeco : public ExpanderDecorator {
public:
InputToAttrDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
: ExpanderDecorator(decorated), input_idx_(input_idx) {}
explicit InputToAttrDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
~InputToAttrDeco() = default;
AnfNodePtr Run(const AnfNodePtr &node) override;
static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
return [input_idx](const ExpanderPtr &decorated) {
return std::static_pointer_cast<Expander>(std::make_shared<InputToAttrDeco>(decorated, input_idx));
};
static ExpanderPtr Creator(const ExpanderPtr &decorated) {
return std::static_pointer_cast<Expander>(std::make_shared<InputToAttrDeco>(decorated));
}
protected:
bool ConstInputToAttr(const CNodePtr &cnode) const;
HashSet<size_t> input_idx_;
AnfNodePtr Run(const AnfNodePtr &node) override;
};
/**

View File

@ -32,5 +32,6 @@ AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const Fu
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool RemoveNonScalarConstTensorFromParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool EliminateMaketupleGetitem(const FuncGraphPtr &fg);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_

View File

@ -24,12 +24,14 @@
#include "utils/hash_map.h"
#include "mindspore/core/ops/core_ops.h"
#include "ir/graph_utils.h"
#include "include/common/utils/anfalgo.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/file_utils.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
@ -430,12 +432,41 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
}
return changed;
}
// For a freshly fused graph-kernel node: convert registered const inputs of every inner
// cnode into attributes, then strip the outer call's now-unreferenced parameters.
void GraphKernelCluster::ConvertInputToAttrForClusterGraph(const AnfNodePtr &graph_kernel_node) {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  // PERF: hoisted out of the loop and taken by const reference — the original code
  // copied the whole registry map once per visited node.
  const auto &all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
  auto todos = TopoSort(gk_graph->get_return());
  for (const auto &n : todos) {
    if (n == nullptr) {
      continue;
    }
    auto node = n->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }
    auto primitive = GetCNodePrimitive(node);
    if (primitive == nullptr) {
      continue;
    }
    auto iter = all_op_index_info.find(primitive->name());
    if (iter != all_op_index_info.end()) {
      (void)ConvertOpUtils::ConstInputToAttr(node, iter->second);
    }
  }
  // Reduce unused inputs for the new node: inputs that became attributes above may have
  // left the outer call with parameters nothing references any more.
  auto cnode = graph_kernel_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto inputs = cnode->inputs();
  AnfNodePtrList fn_inputs{inputs};
  EliminateRedundantParameters(gk_graph, &fn_inputs);
  cnode->set_inputs(fn_inputs);
}
void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
AnfNodePtrList old_nodes;
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
[this](size_t id) { return this->nodes_[id]; });
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
ConvertInputToAttrForClusterGraph(new_node);
if (GraphKernelFlags::GetInstance().dump_as_text) {
DumpClusterInfo(old_nodes, new_node);
}

View File

@ -45,6 +45,7 @@ class GraphKernelCluster : public opt::Pass {
std::vector<size_t> FindCandidates(size_t basenode_id);
void RemoveWildGetitem(std::vector<size_t> *candidates);
void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id);
void ConvertInputToAttrForClusterGraph(const AnfNodePtr &graph_kernel_node);
void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node);
void DumpToFile();
void Clean() {

View File

@ -23,6 +23,7 @@
#include "utils/anf_utils.h"
#include "common/graph_kernel/core/graph_builder.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
namespace mindspore::graphkernel {
AnfNodePtr GraphKernelExpander::CreateExpandedNode(const CNodePtr &node, const std::string &name) const {
@ -51,7 +52,11 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
!AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
continue;
}
auto all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
auto iter = all_op_index_info.find(GetCNodePrimitive(node)->name());
if (iter != all_op_index_info.end()) {
(void)ConvertOpUtils::ConstInputToAttr(node, iter->second);
}
MS_LOG(DEBUG) << "Expanding node: " << node->fullname_with_scope();
auto newnode = InitExpander(node)->Run(node);
if (newnode == nullptr) {

View File

@ -27,6 +27,7 @@
#include "kernel/akg/akg_kernel_json_generator.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/split_model/split_model_factory.h"
namespace mindspore::graphkernel {
@ -367,6 +368,7 @@ class Splitter {
private:
void ResetInlinedNodesKernelInfo() const {
for (const auto &node : inlined_nodes_) {
ConvertOpUtils::ConvertAttrToInput(node);
Callback::Instance()->ResetKernelInfo(node);
}
}

View File

@ -988,11 +988,21 @@ std::vector<DShape> StandardNormalOp::InferShape(const NodePtrList &, const DAtt
}
void StridedSliceOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) {
if (!primitive->HasAttr("new_axis_mask")) {
(void)primitive->AddAttr("new_axis_mask", MakeValue((int64_t(0))));
}
if (!primitive->HasAttr("shrink_axis_mask")) {
(void)primitive->AddAttr("shrink_axis_mask", MakeValue((int64_t(0))));
}
if (!primitive->HasAttr("end_mask")) {
(void)primitive->AddAttr("end_mask", MakeValue((int64_t(0))));
}
if (!primitive->HasAttr("begin_mask")) {
(void)primitive->AddAttr("begin_mask", MakeValue((int64_t(0))));
}
if (!primitive->HasAttr("ellipsis_mask")) {
(void)primitive->AddAttr("ellipsis_mask", MakeValue((int64_t(0))));
}
const mindspore::HashSet<size_t> &convert_input_list{1, 2, 3};
auto input_names = primitive->GetAttr(kAttrInputNames);
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
@ -1181,4 +1191,38 @@ void TupleGetItemOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrLi
(void)abs_list->emplace_back(prim->GetAttr("index")->ToAbstract());
}
}
// Append the "axis" attribute, when present, as an extra abstract operand for inference.
void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  auto axis_value = prim->HasAttr("axis") ? prim->GetAttr("axis") : nullptr;
  if (axis_value != nullptr) {
    (void)abs_list->emplace_back(axis_value->ToAbstract());
  }
}
// Append the target shape attribute as an extra abstract operand; "shape" takes
// precedence over the "input_shape" alias.
void ReshapeOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  for (const auto &attr_name : {"shape", "input_shape"}) {
    if (prim->HasAttr(attr_name)) {
      (void)abs_list->emplace_back(prim->GetAttr(attr_name)->ToAbstract());
      break;
    }
  }
}
// Append the permutation attribute as an extra abstract operand; "input_perm" takes
// precedence over the "perm" alias.
void TransposeOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  for (const auto &attr_name : {"input_perm", "perm"}) {
    if (prim->HasAttr(attr_name)) {
      (void)abs_list->emplace_back(prim->GetAttr(attr_name)->ToAbstract());
      break;
    }
  }
}
// Append the "multiples" attribute, when present, as an extra abstract operand.
void TileOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  auto multiples_value = prim->HasAttr("multiples") ? prim->GetAttr("multiples") : nullptr;
  if (multiples_value != nullptr) {
    (void)abs_list->emplace_back(multiples_value->ToAbstract());
  }
}
// Append the "depth" attribute, when present, as an extra abstract operand.
void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  auto depth_value = prim->HasAttr("depth") ? prim->GetAttr("depth") : nullptr;
  if (depth_value != nullptr) {
    (void)abs_list->emplace_back(depth_value->ToAbstract());
  }
}
} // namespace mindspore::graphkernel::inner

View File

@ -96,6 +96,7 @@ class ReshapeOp : public PrimOp {
~ReshapeOp() = default;
protected:
void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list);
DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override {
return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
: GetValue<std::string>(attrs.find("format")->second);
@ -118,12 +119,22 @@ class BroadcastOp : public PrimOp {
~BroadcastOp() = default;
};
class TileOp : public BroadcastOp {
public:
explicit TileOp(const std::string &op) : BroadcastOp(op) {}
~TileOp() = default;
protected:
void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
};
class ReduceOp : public PrimOp {
public:
explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {}
~ReduceOp() = default;
protected:
void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
};
@ -159,9 +170,18 @@ class TransposeOp : public OpaqueOp {
~TransposeOp() = default;
protected:
void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};
class OneHotOp : public OpaqueOp {
public:
explicit OneHotOp(const std::string &op) : OpaqueOp(op) {}
~OneHotOp() = default;
protected:
void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
};
class LayoutTransformOp : public OpaqueOp {
public:
explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {}

View File

@ -109,18 +109,18 @@ AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node) {
}
auto expander = graphkernel::GetExpander(node);
std::map<std::string, graphkernel::ExpanderCreatorFuncList> creators = {
{prim::kPrimExpandDims->name(), {InputToAttrDeco::GetCreator({1})}},
{prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1})}},
{prim::kPrimReduceMean->name(), {InputToAttrDeco::GetCreator({1})}},
{prim::kPrimGather->name(), {InputToAttrDeco::GetCreator({2})}},
{kTileOpName, {InputToAttrDeco::GetCreator({1})}},
{kSliceOpName, {InputToAttrDeco::GetCreator({1, 2})}},
{prim::kPrimExpandDims->name(), {InputToAttrDeco::Creator}},
{prim::kPrimReshape->name(), {InputToAttrDeco::Creator}},
{prim::kPrimReduceMean->name(), {InputToAttrDeco::Creator}},
{prim::kPrimGather->name(), {InputToAttrDeco::Creator}},
{kTileOpName, {InputToAttrDeco::Creator}},
{kSliceOpName, {InputToAttrDeco::Creator}},
};
const auto &iter = creators.find(GetCNodePrimitive(node)->name());
if (iter != creators.end()) {
expander = graphkernel::WrapExpander(expander, iter->second);
}
expander = graphkernel::AttrToInputDeco::GetCreator(false)(expander);
expander = graphkernel::AttrToInputDeco::Creator(expander);
expander = PrimToPrimPyDecorator::Creator(expander);
auto new_node = expander->Run(node);
auto expand_fg = GetCNodeFuncGraph(new_node);

View File

@ -481,6 +481,7 @@ constexpr auto kAttrReshapeType = "reshape_type";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrAxes = "axes";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kAttrSkipMode = "skip_mode";
constexpr auto kAttrShapeGamma = "shape_gamma";
constexpr auto kAttrPerm = "perm";
constexpr auto kAttrTransposeFirst = "transpose_first";

View File

@ -35,6 +35,7 @@
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/utils.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore {
namespace kernel {
@ -76,6 +77,7 @@ class CNodeDecoder {
}
CreateKernelInfo(processor);
CreateAbstract();
InitIOName(cnode_);
return cnode_;
}
@ -109,6 +111,21 @@ class CNodeDecoder {
}
}
// Copy kAttrInputNames/kAttrOutputNames from the registered PrimitiveC prototype onto
// the decoded cnode's primitive, so later passes that rely on IO names can handle it.
void InitIOName(const CNodePtr &cnode) {
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
  auto const iter = op_primc_fns.find(primitive->name());
  if (iter == op_primc_fns.end()) {
    // Op not registered in the PrimitiveC map: nothing to copy, leave the primitive as-is.
    return;
  }
  // iter->second is a factory; instantiate the prototype to read its default attrs.
  auto prim = iter->second();
  if (prim != nullptr) {
    (void)primitive->AddAttr(kAttrInputNames, prim->GetAttr(kAttrInputNames));
    (void)primitive->AddAttr(kAttrOutputNames, prim->GetAttr(kAttrOutputNames));
  }
}
bool DecodeAttrs(const nlohmann::json &attrs_json) {
MS_LOG(DEBUG) << "start decode attrs, " << attrs_json;
// attrs maybe empty

View File

@ -50,7 +50,7 @@ void EmbeddingLookUpCommGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node
if (IsDynamic(input_shape_)) {
return;
}
if (input_shape_.size() < 1) {
if (input_shape_.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of input must be at least 1-D, but got: " << input_shape_.size() << "-D";
}

View File

@ -23,7 +23,7 @@ namespace kernel {
namespace {
constexpr size_t kEmbeddingLookupInputsNum = 3;
constexpr size_t kEmbeddingLookUpInputParamsMaxDim = 2;
constexpr size_t kIndex2 = 2;
constexpr size_t kOffsetIndex = 2;
using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
#define ADD_KERNEL(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type) \
@ -152,7 +152,7 @@ bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
T *input_params_addr = reinterpret_cast<T *>(inputs[0]->addr);
S *input_indices_addr = reinterpret_cast<S *>(inputs[1]->addr);
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
G offset = static_cast<G *>(inputs[kIndex2]->addr)[0];
G offset = static_cast<G *>(inputs[kOffsetIndex]->addr)[0];
offset_ = static_cast<int64_t>(offset);
auto task = [&](size_t start, size_t end) {

View File

@ -112,70 +112,8 @@ std::vector<KernelAttr> MemcpyCpuKernelMod::common_two_valid_types_with_bool_com
std::vector<KernelAttr> MemcpyCpuKernelMod::GetOpSupport() {
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kReshape,
{
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
// double input
// index int64
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
// index int32
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
// complex
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128),
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
}},
{kFlatten, common_valid_types_with_bool_complex_},
{kFlattenGrad, common_two_valid_types_with_bool_complex_},
{kExpandDims, common_two_valid_types_with_bool_complex_},
{kReshape, common_two_valid_types_with_bool_complex_}, {kFlatten, common_valid_types_with_bool_complex_},
{kFlattenGrad, common_two_valid_types_with_bool_complex_}, {kExpandDims, common_two_valid_types_with_bool_complex_},
{kSqueeze, common_valid_types_with_bool_complex_},
};

View File

@ -81,6 +81,7 @@ class ReduceCpuKernelFunc : public CpuKernelFunc {
bool simple_execute_{false};
std::string kernel_name_;
bool need_skip_execute_{false};
bool skip_mode_{false};
};
template <typename T>
@ -234,13 +235,11 @@ int ReduceCpuKernelFunc<T>::Resize(const BaseOperatorPtr &base_operator, const s
const std::vector<KernelTensorPtr> &,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
input_shape_ = inputs[0]->GetDeviceShapeAdaptively();
auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
if (kernel_ptr->HasAttr(kAttrAxis)) {
axis_ = kernel_ptr->get_axis();
if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &axis_, false)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << " can't get axis input! ";
}
(void)TryGetIntValue(inputs, kAxisIndex_, kernel_name_, &axis_, false);
if (inputs.size() > kAxisIndex_ &&
AnfAlgo::IsDynamicShapeSkipExecute(kernel_name_, inputs[kAxisIndex_]->GetShapeVector())) {
AnfAlgo::IsDynamicShapeSkipExecute(skip_mode_, inputs[kAxisIndex_]->GetShapeVector())) {
need_skip_execute_ = true;
} else {
need_skip_execute_ = false;
@ -302,6 +301,8 @@ void ReduceCpuKernelFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, cons
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
ChooseFunc(kernel_name_);
auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
skip_mode_ = kernel_ptr->get_skip_mode();
}
template <typename T>
@ -442,20 +443,8 @@ using SpecializeReduceFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>
#define REDUCE_CPU_REG(MS_T, MS_S, T) \
KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_T), SpecializeReduceFunc<T>
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_all_any_list = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeReduceFunc<bool>},
{REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt64, bool)}};
{REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt32, bool)}, {REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt64, bool)}};
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_max_min_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SpecializeReduceFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), SpecializeReduceFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), SpecializeReduceFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SpecializeReduceFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SpecializeReduceFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SpecializeReduceFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SpecializeReduceFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), SpecializeReduceFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), SpecializeReduceFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), SpecializeReduceFunc<uint64_t>},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_CPU_REG(kNumberTypeFloat64, kNumberTypeInt32, double)},
@ -477,20 +466,6 @@ static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_ma
{REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt32, uint64_t)},
{REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt64, uint64_t)}};
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_sum_prod_mean_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SpecializeReduceFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), SpecializeReduceFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), SpecializeReduceFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SpecializeReduceFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SpecializeReduceFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SpecializeReduceFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SpecializeReduceFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), SpecializeReduceFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), SpecializeReduceFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), SpecializeReduceFunc<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
SpecializeReduceFunc<complex64>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
SpecializeReduceFunc<complex128>},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_CPU_REG(kNumberTypeFloat64, kNumberTypeInt32, double)},

View File

@ -25,8 +25,7 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kSliceInputsNum = 1;
constexpr size_t kSliceDynamicInputNum = 3;
constexpr size_t kSliceInputsNum = 3;
constexpr size_t kSliceOutputsNum = 1;
constexpr size_t kSliceInputIndex2 = 2;
constexpr size_t kSliceTwoDims = 2;
@ -59,12 +58,23 @@ bool SliceCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
{kNumberTypeFloat16, sizeof(float16)},
{kNumberTypeComplex64, sizeof(std::complex<float>)},
{kNumberTypeComplex128, sizeof(std::complex<double>)}};
size_t input_num = inputs.size();
if (input_num != kSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "input size should be " << kSliceInputsNum;
}
TypeId dtype = inputs[0]->GetDtype();
auto size_pair = type_size_map.find(dtype);
if (size_pair == type_size_map.end()) {
MS_LOG(EXCEPTION) << "Slice supports type in type_size_map, but got " << TypeIdToType(dtype)->ToString();
}
data_size_ = size_pair->second;
MS_EXCEPTION_IF_CHECK_FAIL(inputs.at(1)->GetDtype() == inputs[kSliceInputIndex2]->GetDtype(),
"Begin and size dtype should be same.");
param_dtype_ = inputs.at(1)->GetDtype();
return true;
}
@ -103,11 +113,17 @@ void SliceCpuKernelMod::InitSliceParam(const ShapeVector &input_shape, const std
int SliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Slice>(base_operator);
is_got_value_ = false;
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
size_t input_num = inputs.size();
if (input_num != kSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "input size should be " << kSliceInputsNum;
}
auto input_shape = inputs[0]->GetShapeVector();
if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@ -115,23 +131,26 @@ int SliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
<< "D.";
}
size_t input_num = inputs.size();
// begin and size are const input
if (input_num == 1) {
auto size = kernel_ptr->get_size();
auto begin = kernel_ptr->get_begin();
std::vector<int64_t> begin;
std::vector<int64_t> size;
auto got_begin = TryGetIntValue(inputs, 1, kernel_name_, &begin, false);
auto got_size = TryGetIntValue(inputs, kSliceInputIndex2, kernel_name_, &size, false);
if (got_begin && got_size) {
if (begin.size() != input_shape.size() || size.size() != input_shape.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the lengths of 'begin' and 'size' must be equal to "
"the dimension of input tensor, but got the length of 'begin' "
<< begin.size() << ", the length of 'size' " << size.size()
<< "and the dimension of input tensor " << input_shape.size();
<< begin.size() << ", the length of 'size' " << size.size() << "and the dimension of input tensor "
<< input_shape.size();
return KRET_RESIZE_FAILED;
}
InitSliceParam(input_shape, begin, size);
} else if (input_num == kSliceDynamicInputNum) {
is_got_value_ = true;
} else {
input_shape_ = inputs[0]->GetShapeVector();
begin_shape_ = inputs[1]->GetShapeVector();
size_shape_ = inputs[kSliceInputIndex2]->GetShapeVector();
is_got_value_ = false;
}
return KRET_OK;
}
@ -150,45 +169,63 @@ void SliceSimpleDim2(const int8_t *input, int8_t *output, const SliceParameter *
bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kSliceInputsNum && inputs.size() != kSliceDynamicInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kSliceInputsNum << " or "
<< kSliceDynamicInputNum << ", but got " << inputs.size() << " input(s).";
if (inputs.size() != kSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kSliceInputsNum
<< ", but got " << inputs.size() << " input(s).";
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSliceOutputsNum, kernel_name_);
auto input_addr = inputs[0]->addr;
auto output_addr = outputs[0]->addr;
if (inputs.size() == kSliceDynamicInputNum) {
if (!is_got_value_) {
// Get begin and size value.
auto input_shape = input_shape_;
auto begin_shape = begin_shape_;
auto size_shape = size_shape_;
if (begin_shape.size() != 1 || size_shape.size() != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the dimensions of 'begin' and 'size' must be 1, but got the dimension of 'begin': "
<< begin_shape.size() << " and the dimension of 'size': " << size_shape.size();
return false;
}
if (begin_shape[0] != SizeToLong(input_shape.size()) || size_shape[0] != SizeToLong(input_shape.size())) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the lengths of 'begin' and 'size' must be equal to "
"the dimension of input tensor, but got the length of 'begin' "
<< begin_shape[0] << ", the length of 'size' " << size_shape[0]
<< "and the dimension of input tensor " << input_shape.size();
return false;
}
std::vector<int64_t> begin;
std::vector<int64_t> size;
if (param_dtype_ == kNumberTypeInt32) {
auto begin_ptr = reinterpret_cast<int32_t *>(inputs[1]->addr);
auto size_ptr = reinterpret_cast<int32_t *>(inputs[kSliceInputIndex2]->addr);
std::vector<int64_t> begin{begin_ptr, begin_ptr + begin_shape[0]};
std::vector<int64_t> size{size_ptr, size_ptr + size_shape[0]};
begin.assign(begin_ptr, begin_ptr + begin_shape[0]);
size.assign(size_ptr, size_ptr + size_shape[0]);
} else if (param_dtype_ == kNumberTypeInt64) {
auto begin_ptr = reinterpret_cast<int64_t *>(inputs[1]->addr);
auto size_ptr = reinterpret_cast<int64_t *>(inputs[kSliceInputIndex2]->addr);
begin.assign(begin_ptr, begin_ptr + begin_shape[0]);
size.assign(size_ptr, size_ptr + size_shape[0]);
} else {
MS_LOG(ERROR) << "Invalid dtype: " << TypeIdLabel(param_dtype_);
return false;
}
for (size_t i = 0; i < begin.size(); ++i) {
if (input_shape[i] < begin[i] + size[i]) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', slice shape should be not greater than origin shape. But in dimension i=" << i
<< ", origin shape 'input_shape[i]' is " << input_shape[i]
<< " and slice shape 'LongToSize(begin[i] + size[i])' is " << LongToSize(begin[i] + size[i]);
return false;
}
}
InitSliceParam(input_shape, begin, size);
}
auto input_addr = inputs[0]->addr;
auto output_addr = outputs[0]->addr;
if (origin_dim_size_ == kSliceTwoDims) {
auto task = [this, &input_addr, &output_addr](size_t start, size_t end) {
auto src =
@ -204,6 +241,22 @@ bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
return true;
}
#define SLICE_CPU_REGISTER_KERNEL_ATTR(DT) \
KernelAttr().AddInputAttr(DT).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(DT), \
KernelAttr().AddInputAttr(DT).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(DT)
std::vector<KernelAttr> SliceCpuKernelMod::GetOpSupport() {
static const std::vector<KernelAttr> support_list = {
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeBool), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt8),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt16), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt32),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt8),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt16), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt32),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat16),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat32), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat64),
SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeComplex64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeComplex128)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Slice, SliceCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_
#include <vector>
#include <memory>
@ -39,39 +39,7 @@ class SliceCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
std::vector<KernelAttr> GetOpSupport() override {
static const std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeBool)};
return support_list;
}
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitSliceParam(const ShapeVector &input_shape, const std::vector<int64_t> &begin,
@ -80,10 +48,12 @@ class SliceCpuKernelMod : public NativeCpuKernelMod {
int data_size_{4};
SliceParameter slice_param_;
size_t output_num_{1};
TypeId param_dtype_{kNumberTypeInt32};
std::vector<int64_t> input_shape_;
std::vector<int64_t> begin_shape_;
std::vector<int64_t> size_shape_;
bool is_got_value_{false};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_

View File

@ -28,8 +28,7 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kStridedSliceInputsNum = 1;
constexpr size_t kStridedSliceDynamicInputsNum = 4;
constexpr size_t kStridedSliceInputsNum = 4;
constexpr size_t kStridedSliceOutputsNum = 1;
using complex64 = std::complex<float>;
@ -58,10 +57,11 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (ret != KRET_OK) {
return ret;
}
if (inputs.size() != kStridedSliceInputsNum && inputs.size() != kStridedSliceDynamicInputsNum) {
if (inputs.size() != kStridedSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kStridedSliceInputsNum
<< " or " << kStridedSliceDynamicInputsNum << ", but got " << inputs.size();
<< ", but got " << inputs.size();
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
input_shape_ = inputs[0]->GetShapeVector();
dtype_ = inputs[0]->GetDtype();
@ -76,7 +76,6 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
InitParallelParam();
}
if (inputs.size() == kStridedSliceDynamicInputsNum) {
// for begin, end, stride are not const input
begin_shape_ = inputs[kIndex1]->GetShapeVector();
end_shape_ = inputs[kIndex2]->GetShapeVector();
@ -88,14 +87,7 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
<< ", and the dimension of 'strides': " << stride_shape_.size();
}
} else {
// for begin, end, stride are tuples
auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
auto begin = kernel_ptr->get_begin();
auto end = kernel_ptr->get_end();
auto stride = kernel_ptr->get_strides();
InitSliceParam(base_operator, &begin, &end, &stride);
}
return ret;
}
@ -244,16 +236,14 @@ template <typename T, typename S>
bool StridedSliceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kStridedSliceInputsNum && inputs.size() != kStridedSliceDynamicInputsNum) {
if (inputs.size() != kStridedSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kStridedSliceInputsNum
<< " or " << kStridedSliceDynamicInputsNum << ", but got " << inputs.size();
<< ", but got " << inputs.size();
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);
size_t input_num = inputs.size();
if (input_num == kStridedSliceDynamicInputsNum) {
// for begin, end, stride are tensors
std::vector<int64_t> begin;
std::vector<int64_t> end;
@ -271,7 +261,6 @@ bool StridedSliceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
stride.push_back(static_cast<int64_t>(strides_ptr[i]));
}
InitSliceParam(base_operator_, &begin, &end, &stride);
}
int thread_num = slice_param_.op_parameter_.thread_num_;
if (parallel_ && thread_num >= 2) {

View File

@ -17,13 +17,14 @@
#include "plugin/device/cpu/kernel/tile_cpu_kernel.h"
#include <algorithm>
#include <map>
#include <numeric>
#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTileInputsNum = 1;
constexpr size_t kTileDynamicInputsNum = 2;
constexpr size_t kTileInputsNum = 2;
constexpr size_t kTileOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
@ -84,18 +85,14 @@ void TileCpuKernelMod::TileMultipleCompute() {
bool TileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = base_operator->name();
input_num = inputs.size();
dtype_ = inputs[kIndex0]->GetDtype();
multiples_.clear();
if (input_num == kTileInputsNum) {
std::vector<int64_t> multiples_me = GetValue<std::vector<int64_t>>(prim->GetAttr("multiples"));
(void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_),
[](const int64_t &value) { return static_cast<int>(value); });
input_num_ = inputs.size();
if (input_num_ != kTileInputsNum) {
MS_LOG(EXCEPTION) << "Tile's inputs number should be " << kTileInputsNum << ", but got " << input_num_;
}
multiples_.clear();
dtype_ = inputs[kIndex0]->GetDtype();
launch_map_[kNumberTypeInt8] = &TileCpuKernelMod::LaunchKernel<int8_t>;
launch_map_[kNumberTypeInt16] = &TileCpuKernelMod::LaunchKernel<int16_t>;
launch_map_[kNumberTypeInt32] = &TileCpuKernelMod::LaunchKernel<int>;
@ -129,20 +126,20 @@ int TileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
x_shape_ = inputs[kIndex0]->GetShapeVector();
y_shape_ = outputs[kIndex0]->GetShapeVector();
if (input_num == kTileDynamicInputsNum) {
multiple_shape = inputs[kIndex1]->GetShapeVector();
multiple_shape_ = inputs[kIndex1]->GetShapeVector();
multiple_dtype_ = inputs[kIndex1]->GetDtype();
}
return KRET_OK;
}
bool TileCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() != kTileInputsNum && inputs.size() != kTileDynamicInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input must be " << kTileInputsNum << " or "
<< kTileDynamicInputsNum << ", but got " << inputs.size();
if (inputs.size() != kTileInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input must be " << kTileInputsNum << ", but got "
<< inputs.size();
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTileOutputsNum, kernel_name_);
launch_func_(this, inputs, outputs);
@ -153,15 +150,8 @@ template <typename T>
void TileCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
if (input_num == kTileInputsNum) {
TileMultipleCompute();
}
if (input_num == kTileDynamicInputsNum) {
multiples_.clear();
int64_t multiple_nums = 1;
for (size_t i = 0; i < multiple_shape.size(); ++i) {
multiple_nums *= multiple_shape[i];
}
auto multiple_nums = std::accumulate(multiple_shape_.begin(), multiple_shape_.end(), 1, std::multiplies<int64_t>());
if (multiple_dtype_ == kNumberTypeInt32) {
auto multiples_addr = GetDeviceAddress<int32_t>(inputs, 1);
for (size_t i = 0; i < LongToSize(multiple_nums); ++i) {
@ -174,7 +164,6 @@ void TileCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
}
}
TileMultipleCompute();
}
tile_parameter_.data_size_ = sizeof(T);
if (one_dim_tile_) {

View File

@ -112,9 +112,9 @@ class TileCpuKernelMod : public NativeCpuKernelMod {
ShapeVector x_shape_;
ShapeVector y_shape_;
ShapeVector multiple_shape;
size_t input_num;
size_t output_num;
size_t input_num_;
std::vector<int> multiples_;
ShapeVector multiple_shape_;
TypeId dtype_{kTypeUnknown};
TypeId multiple_dtype_{kTypeUnknown};

View File

@ -27,8 +27,7 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTransposeInputsNum = 1;
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kTransposeInputNum = 2;
constexpr size_t kTransposeOutputsNum = 1;
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
@ -86,6 +85,9 @@ int TransposeFwdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (ret != KRET_OK) {
return ret;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputNum, kernel_name_);
input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
output_shape_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
dtype_ = inputs[kIndex0]->GetDtype();
@ -98,23 +100,24 @@ int TransposeFwdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
strides_[i - 1] = input_shape_[i] * strides_[i];
out_strides_[i - 1] = output_shape_[i] * out_strides_[i];
}
if (inputs.size() == kDynamicPermInputNum) {
perm_type_ = inputs[kIndex1]->GetDtype();
perm_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
return KRET_OK;
}
auto prim = std::dynamic_pointer_cast<ops::Transpose>(base_operator);
MS_ERROR_IF_NULL(prim);
perm_ = prim->get_perm();
perm_.clear();
got_perm_value_ = TryGetIntValue(inputs, kIndex1, kernel_name_, &perm_, false);
if (got_perm_value_) {
CheckPermValue();
}
return KRET_OK;
}
bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
if (inputs.size() == kDynamicPermInputNum) {
if (!got_perm_value_) {
if (perm_type_ == kNumberTypeInt32) {
InitPerm<int32_t>(inputs);
} else {
@ -126,34 +129,6 @@ bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inp
}
std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> TransposeFwdCpuKernelMod::launch_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&TransposeFwdCpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&TransposeFwdCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&TransposeFwdCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&TransposeFwdCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&TransposeFwdCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&TransposeFwdCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&TransposeFwdCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&TransposeFwdCpuKernelMod::LaunchKernel<complex64>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&TransposeFwdCpuKernelMod::LaunchKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&TransposeFwdCpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
@ -165,7 +140,7 @@ std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> Transpo
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
@ -193,7 +168,7 @@ std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> Transpo
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),

View File

@ -80,6 +80,7 @@ class TransposeFwdCpuKernelMod : public NativeCpuKernelMod {
size_t data_num_{0};
std::vector<int64_t> strides_;
std::vector<int64_t> out_strides_;
bool got_perm_value_{false};
using TypeKernel =
std::function<void(TransposeFwdCpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,6 +48,25 @@ std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) {
return axis;
}
ValueNodePtr CreateValueNode(const std::vector<int64_t> &transpose_perm, const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto tensor_ptr = std::make_shared<tensor::Tensor>(transpose_perm, TypeIdToType(kNumberTypeInt64));
MS_EXCEPTION_IF_NULL(tensor_ptr);
auto value_node = std::make_shared<ValueNode>(tensor_ptr);
MS_EXCEPTION_IF_NULL(value_node);
value_node->set_abstract(tensor_ptr->ToAbstract());
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
if (kernel_graph != nullptr) {
value_node = kernel_graph->NewValueNode(value_node);
kernel_graph->AddValueNodeToGraph(value_node);
} else {
value_node = MakeValueNode(value_node);
}
return value_node;
}
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int64_t> &transpose_perm) {
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
@ -58,13 +77,13 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
auto transpose_prim = std::make_shared<Primitive>(primitive_ptr->name());
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2.Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
auto perm_value_node = CreateValueNode(transpose_perm, graph);
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node, perm_value_node};
auto transpose_op = graph->NewCNode(transpose_input);
// 3.Set the output info of transpose.
auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, IntToSize(used_node_index))};
auto transpose_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))};
common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, transpose_shape, transpose_op.get());
common::AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
// 4. Set the new edge of transpose op.
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -78,8 +97,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
auto input_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto output_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetInputsDeviceType({input_type});
builder.SetInputsFormat({input_format, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({input_type, kNumberTypeInt64});
builder.SetOutputsFormat({output_format});
builder.SetOutputsDeviceType({output_type});
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);

View File

@ -24,14 +24,6 @@ namespace mindspore::opt {
RER_CPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);
@ -48,19 +40,7 @@ RER_CPU_STATIC_CONST_TO_ATTR(kCumprodOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
RER_CPU_STATIC_CONST_TO_ATTR(kMeanGradOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kParallelResizeBilinearGradOpName, 2);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReshapeOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
RER_CPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
RER_CPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
RER_CPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
} // namespace mindspore::opt
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_OPTIMIZER_REG_CPU_CONST_INPUT_TO_ATTR_H_

View File

@ -48,72 +48,51 @@ const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
{"ReduceProd", CUDNN_REDUCE_TENSOR_MUL},
};
#define STATIC_REGISTER(INPUTX, T) \
KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(INPUTX), &ArrayReduceGpuKernelMod::LaunchKernel<T>
#define DYN_REGISTER(INPUTX, AXIS, T) \
#define REDUCE_REGISTER(INPUTX, AXIS, T) \
KernelAttr().AddInputAttr(INPUTX).AddInputAttr(AXIS).AddOutputAttr(INPUTX), &ArrayReduceGpuKernelMod::LaunchKernel<T>
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::all_any_list_ = {
{STATIC_REGISTER(kNumberTypeBool, bool)},
{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::prod_list_ = {
{STATIC_REGISTER(kNumberTypeFloat64, double)},
{STATIC_REGISTER(kNumberTypeFloat32, float)},
{STATIC_REGISTER(kNumberTypeFloat16, half)},
{STATIC_REGISTER(kNumberTypeInt8, int8_t)},
{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
{DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::sum_list_ = {
{STATIC_REGISTER(kNumberTypeFloat64, double)},
{STATIC_REGISTER(kNumberTypeFloat32, float)},
{STATIC_REGISTER(kNumberTypeFloat16, half)},
{STATIC_REGISTER(kNumberTypeBool, bool)},
{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
{STATIC_REGISTER(kNumberTypeFloat64, double)},
{STATIC_REGISTER(kNumberTypeFloat32, float)},
{STATIC_REGISTER(kNumberTypeFloat16, half)},
{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>>
ArrayReduceGpuKernelMod::kernel_attr_list_ = {
@ -171,6 +150,10 @@ bool ArrayReduceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
data_type_ = GetCudnnDataType(type_name);
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
keep_dims_ = kernel_ptr->get_keep_dims();
skip_mode_ = kernel_ptr->get_skip_mode();
InitResource();
return true;
}
@ -202,27 +185,17 @@ int ArrayReduceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return ret;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
auto inputA_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
std::vector<int64_t> attr_axis;
if (inputs.size() == kDynamicAxisInputNum) {
if (AnfAlgo::IsDynamicShapeSkipExecute(kernel_name_, inputs[kIndex1]->GetShapeVector())) {
if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &attr_axis)) {
MS_LOG(EXCEPTION) << "For " << kernel_name_ << " can't get axis input! ";
}
if (AnfAlgo::IsDynamicShapeSkipExecute(skip_mode_, inputs[kIndex1]->GetShapeVector())) {
need_skip_execute_ = true;
// As size of input_size_list_ is equal to size of inputs, input_size_list_[0] is safe.
input_size_ = input_size_list_[0];
return KRET_OK;
}
if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &attr_axis)) {
InitCudnnResource();
return KRET_OK;
}
} else {
if (kernel_ptr->HasAttr(kAttrAxis)) {
attr_axis = kernel_ptr->get_axis();
}
}
keep_dims_ = kernel_ptr->get_keep_dims();
int input_dim_length = SizeToInt(inputA_shape.size());
axis_.clear();

View File

@ -68,6 +68,7 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
outputC_descriptor_ = nullptr;
need_skip_execute_ = false;
keep_dims_ = false;
skip_mode_ = false;
all_match_ = false;
is_null_input_ = false;
complex_op_type = 0;
@ -120,6 +121,7 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
std::vector<int> axis_;
bool keep_dims_;
bool skip_mode_;
bool need_skip_execute_;
bool all_match_;
bool is_null_input_;

View File

@ -31,9 +31,6 @@ bool OneHotGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::v
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input num should be 3 or 4, but get: " << inputs.size();
return false;
}
if (inputs.size() == max_input_num) {
is_dynamic_shape_ = true;
}
constexpr size_t output_num = 1;
kernel_name_ = base_operator->GetPrim()->name();
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
@ -94,10 +91,8 @@ bool OneHotGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
size_t on_value_idx = 1;
size_t off_value_idx = 2;
if (is_dynamic_shape_) {
on_value_idx++;
off_value_idx++;
}
const S *indices = GetDeviceAddress<S>(inputs, 0);
const T *on_value = GetDeviceAddress<T>(inputs, on_value_idx);
const T *off_value = GetDeviceAddress<T>(inputs, off_value_idx);

View File

@ -55,7 +55,6 @@ class OneHotGpuKernelMod : public NativeGpuKernelMod {
static std::vector<std::pair<KernelAttr, OneHotLaunchFunc>> func_list_;
OneHotLaunchFunc kernel_func_;
bool is_dynamic_shape_ = false;
size_t depth_{0};
size_t left_dim_size_{1};
size_t right_dim_size_{1};

View File

@ -26,20 +26,7 @@ std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSliceKernelPtr(const std::s
using SlicePtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateSliceKernelPtr<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateSliceKernelPtr<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateSliceKernelPtr<half>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateSliceKernelPtr<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateSliceKernelPtr<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateSliceKernelPtr<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateSliceKernelPtr<char>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateSliceKernelPtr<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateSliceKernelPtr<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateSliceKernelPtr<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateSliceKernelPtr<uchar>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateSliceKernelPtr<bool>},
{KernelAttr()
const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)

View File

@ -34,13 +34,7 @@ class StridedSliceGpuCommon {
StridedSliceGpuCommon() : null_output_(false) {}
~StridedSliceGpuCommon() = default;
void CollectInfo(const BaseOperatorPtr &base_operator, bool is_dynamic_attr_ = false) {
if (!is_dynamic_attr_) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
begin_ = kernel_ptr->get_begin();
end_ = kernel_ptr->get_end();
strides_ = kernel_ptr->get_strides();
}
void CollectInfo(const BaseOperatorPtr &base_operator) {
auto shape_tmp = Convert2Long(input_shape_);
FillEmptyDims(base_operator, &begin_, &end_, &strides_, &shape_tmp);
input_shape_ = Convert2SizeT(shape_tmp);

View File

@ -22,7 +22,6 @@
namespace mindspore {
namespace kernel {
constexpr size_t DynamicInputNum = 4;
template <typename T>
using Complex = mindspore::utils::Complex<T>;
@ -57,10 +56,6 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return ret;
}
if (inputs.size() == DynamicInputNum) {
is_dynamic_attr_ = true;
}
auto shape_signed = inputs[0]->GetShapeVector();
input_shape_ = Convert2SizeTClipNeg(shape_signed);
null_output_ = CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input");
@ -71,12 +66,11 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
<< ", but got " << input_shape_.size();
}
if (is_dynamic_attr_) {
TryGetIntValue(inputs, kBeginIndex_, kernel_name_, &begin_);
TryGetIntValue(inputs, kEndIndex_, kernel_name_, &end_);
TryGetIntValue(inputs, kStrideIndex_, kernel_name_, &strides_);
}
CollectInfo(base_operator, is_dynamic_attr_);
CollectInfo(base_operator);
return ret;
}
@ -94,21 +88,6 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
&StridedSliceGpuKernelMod::LaunchKernel<TYPE_1, TYPE_2>
std::vector<std::pair<KernelAttr, StridedSliceGpuKernelMod::StridedSliceFunc>> StridedSliceGpuKernelMod::func_list_ = {
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat64, double)},
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat32, float)},
{STRIDEDSLICE_GPU_REG(kNumberTypeFloat16, half)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt64, int64_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt32, int32_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt16, int16_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeInt8, int8_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt64, uint64_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt32, uint32_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt16, uint16_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeUInt8, uint8_t)},
{STRIDEDSLICE_GPU_REG(kNumberTypeBool, bool)},
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex64, Complex<float>)},
{STRIDEDSLICE_GPU_REG(kNumberTypeComplex128, Complex<double>)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t)},

View File

@ -59,8 +59,6 @@ class StridedSliceGpuKernelMod : public NativeGpuKernelMod, public StridedSliceG
StridedSliceFunc kernel_func_;
bool is_null_input_{false};
bool is_dynamic_attr_{false};
bool get_dynamic_attr_value_{false};
static constexpr size_t kBeginIndex_{1};
static constexpr size_t kEndIndex_{2};
static constexpr size_t kStrideIndex_{3};

View File

@ -18,38 +18,6 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
TensorCopySlicesGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorCopySlicesGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorCopySlicesGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TensorCopySlicesGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorCopySlicesGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
TensorCopySlicesGpuKernelMod, char)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
TensorCopySlicesGpuKernelMod, uchar)
MS_REG_GPU_KERNEL_ONE(
TensorCopySlices,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
TensorCopySlicesGpuKernelMod, bool)
MS_REG_GPU_KERNEL_TWO(TensorCopySlices,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)

View File

@ -70,9 +70,6 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
return ret;
}
ResetResource();
if (inputs.size() == DynamicInputNum) {
is_dynamic_attr_ = true;
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), OutputNum, kernel_name_);
auto shape_signed = inputs.at(kIndex0)->GetShapeVector();
input_shape_ = Convert2SizeTClipNeg(shape_signed);
@ -87,17 +84,9 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxDims
<< ", but got " << input_shape_.size();
}
if (is_dynamic_attr_) {
TryGetIntValue(inputs, kBeginIndex_, kernel_name_, &begin_, false);
TryGetIntValue(inputs, kEndIndex_, kernel_name_, &end_, false);
TryGetIntValue(inputs, kStrideIndex_, kernel_name_, &strides_, false);
} else {
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
begin_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrBegin));
end_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrEnd));
strides_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrStrides));
}
if (begin_.size() > input_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the size of 'begin' cannot be greater than the dimension of input, but got the "
@ -118,7 +107,6 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
output_size_ = 1;
update_size_ = 1;
is_null_input_ = false;
is_dynamic_attr_ = false;
input_shape_.clear();
output_shape_.clear();
update_shape_.clear();
@ -224,7 +212,7 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
size_t output_size_;
inline static size_t kMaxDims = 8;
bool is_null_input_;
bool is_dynamic_attr_{false};
std::string kernel_name_;
static constexpr size_t kBeginIndex_{2};
static constexpr size_t kEndIndex_{3};
static constexpr size_t kStrideIndex_{4};

View File

@ -22,8 +22,7 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kStaticInputNum = 1;
constexpr size_t kDynInputNum = 2;
constexpr size_t kInputNum = 2;
constexpr size_t kTileOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
@ -35,12 +34,8 @@ bool TileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = base_operator->name();
size_t input_num = inputs.size();
if (input_num == kStaticInputNum) {
is_dynamic_case_ = false;
} else if (input_num == kDynInputNum) {
is_dynamic_case_ = true;
} else {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 1 or 2, but got " << input_num;
if (input_num != kInputNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
return false;
}
@ -83,16 +78,7 @@ int TileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
}
shape_size_ = output_shape_.size();
output_size_ = SizeOf(output_shape_);
if (!is_dynamic_case_) {
const std::string kAttrMultiples = "multiples";
auto multi_attr = base_operator->GetPrim()->GetAttr(kAttrMultiples);
if (multi_attr == nullptr) {
return KRET_RESIZE_FAILED;
}
multiples = GetValue<std::vector<int64_t>>(multi_attr);
} else {
TryGetIntValue(inputs, kIndex1, kernel_name_, &multiples);
}
int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
(void)input_shape_.insert(input_shape_.begin(), filling_value, kIndex1);
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
@ -127,24 +113,6 @@ template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, TileGpuKernelMod::TileLaunchFunc>> TileGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&TileGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&TileGpuKernelMod::LaunchKernel<Complex<double>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&TileGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&TileGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&TileGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&TileGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), &TileGpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&TileGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), &TileGpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), &TileGpuKernelMod::LaunchKernel<bool>},
// For dynamic shape case:
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&TileGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),

View File

@ -50,7 +50,6 @@ class TileGpuKernelMod : public NativeGpuKernelMod {
output_size_ = 1;
shape_size_ = 1;
is_null_input_ = false;
is_dynamic_case_ = false;
input_shape_.clear();
output_shape_.clear();
input_size_list_.clear();
@ -68,7 +67,6 @@ class TileGpuKernelMod : public NativeGpuKernelMod {
ShapeVector input_shape_;
ShapeVector output_shape_;
bool is_null_input_;
bool is_dynamic_case_;
std::vector<int64_t> multiples;
using TileLaunchFunc =
std::function<bool(TileGpuKernelMod *, const std::vector<kernel::AddressPtr> &,

View File

@ -24,7 +24,7 @@ namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kPermInputNum = 2;
constexpr size_t kDimSize4 = 4;
constexpr size_t kAxisZero = 0;
constexpr size_t kAxis1st = 1;
@ -35,10 +35,7 @@ constexpr size_t kAxisIndex1st = 1;
constexpr size_t kAxisIndex2nd = 2;
constexpr size_t kAxisIndex3rd = 3;
#define STATIC_REGISTER(INPUTX, OUTPUT, T) \
{ KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(OUTPUT), &TransposeGpuKernelMod::LaunchKernel<T> }
#define DYN_REGISTER(INPUTX, PERM, OUTPUT, T) \
#define OP_REGISTER(INPUTX, PERM, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUTX).AddInputAttr(PERM).AddOutputAttr(OUTPUT), \
&TransposeGpuKernelMod::LaunchKernel<T> \
@ -47,48 +44,34 @@ constexpr size_t kAxisIndex3rd = 3;
const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> &TransposeGpuKernelMod::GetFuncList()
const {
static const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> func_list = {
STATIC_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
STATIC_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
STATIC_REGISTER(kNumberTypeBool, kNumberTypeBool, bool),
STATIC_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double),
STATIC_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float),
STATIC_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half),
STATIC_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t),
STATIC_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t),
STATIC_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t),
STATIC_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t),
STATIC_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t),
STATIC_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t),
STATIC_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t),
STATIC_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t),
DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
OP_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
OP_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
OP_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
OP_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
OP_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
OP_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
OP_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
OP_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
OP_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
OP_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
OP_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
OP_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
OP_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
OP_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
OP_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
OP_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
OP_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
OP_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
OP_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
OP_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
OP_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
OP_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
OP_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
OP_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
OP_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
OP_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
OP_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
OP_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
};
return func_list;
}
@ -102,10 +85,6 @@ bool TransposeGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
if (is_dynamic_perm_ && !get_dynamic_perm_value_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic perm!";
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
@ -155,26 +134,13 @@ bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
size_t input_num = inputs.size();
size_t output_num = outputs.size();
kernel_name_ = base_operator->name();
if (input_num != 1 && input_num != kDynamicPermInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
<< ", but got " << input_num;
if (input_num != kPermInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kPermInputNum << ", but got "
<< input_num;
}
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
}
if (input_num == kDynamicPermInputNum) {
is_dynamic_perm_ = true;
return true;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Transpose>(base_operator);
if (!kernel_ptr) {
MS_LOG(ERROR) << "For primitive[Transpose], cast op from BaseOperator to Transpose failed.";
return false;
}
std::vector<int64_t> perm;
perm = kernel_ptr->get_perm();
GetPermValue(perm);
return true;
}

View File

@ -120,16 +120,9 @@ std::vector<KernelAttr> MemcpyGpuKernelMod::GetOpSupport() {
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128)};
std::vector<KernelAttr> reshape_valid_types;
reshape_valid_types.insert(reshape_valid_types.end(), common_valid_types_with_single_input.begin(),
common_valid_types_with_single_input.end());
reshape_valid_types.insert(reshape_valid_types.end(), common_valid_types_with_double_input.begin(),
common_valid_types_with_double_input.end());
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kReshape, reshape_valid_types},
{kFlatten, common_valid_types_with_single_input},
{kFlattenGrad, common_valid_types_with_single_input},
{kExpandDims, common_valid_types_with_double_input},
{kReshape, common_valid_types_with_double_input}, {kFlatten, common_valid_types_with_single_input},
{kFlattenGrad, common_valid_types_with_single_input}, {kExpandDims, common_valid_types_with_double_input},
{kSqueeze, common_valid_types_with_single_input},
};

View File

@ -37,6 +37,8 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
(void)node_inputs.insert(node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
CNodePtr new_cnode = func_graph->NewCNode(node_inputs);
MS_EXCEPTION_IF_NULL(new_cnode);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto predict_input = cnode->inputs()[1];
auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)};
@ -46,10 +48,20 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
string reduction = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
MS_LOG(INFO) << "Create reduce node for BCEWithLogitsLoss, reduction attr is: " << reduction;
std::vector<AnfNodePtr> reduce_inputs;
std::vector<int64_t> axis_shp = {0};
auto axis_tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), axis_shp);
MS_EXCEPTION_IF_NULL(axis_tensor);
tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
axis_tensor->set_device_info(device_info);
ValueNodePtr axis_node = std::make_shared<ValueNode>(axis_tensor);
MS_EXCEPTION_IF_NULL(axis_node);
axis_node->set_abstract(axis_tensor->ToAbstract());
axis_node = kernel_graph->NewValueNode(axis_node);
kernel_graph->AddValueNode(axis_node);
if (reduction == "sum") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode, axis_node};
} else if (reduction == "mean") {
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode, axis_node};
} else {
MS_LOG(INFO) << "Reduction is none, no optimization on current BCEWithLogitsLoss.";
return nullptr;
@ -59,7 +71,6 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
common::AnfAlgo::SetOutputTypeAndDetailShape({type}, shape, reduce_node.get());
common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
common::AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
reduce_node->set_scope(cnode->scope());
return reduce_node;

View File

@ -31,6 +31,23 @@ namespace opt {
namespace {
const int kFakeTransposeShapeOneNum = 2;
// Convert a ValueNode holding an int64 tuple (e.g. a perm/shape attribute) into a
// Tensor-backed ValueNode registered in the kernel graph, so it can be used as a
// real operator input instead of a primitive attribute.
//
// \param kernel_graph Graph that will own the new value node; must not be null.
// \param input_node   ValueNode whose value is a ValueTuple of integers; must not be null.
// \return The tensor-backed ValueNode, already added to the graph's value-node set.
AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  // input_node is already a ValueNodePtr, so read its value directly (no re-cast needed).
  auto value = input_node->value();
  MS_EXCEPTION_IF_NULL(value);
  // CreateTupleTensor materializes the int tuple as a 1-D tensor; a null result means
  // the value was not a ValueTuple (or was otherwise unconvertible).
  tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  // Re-create the node through the kernel graph so it gets graph-managed kernel info,
  // then register it so it is kept alive with the graph.
  tensor_input = kernel_graph->NewValueNode(tensor_input);
  kernel_graph->AddValueNodeToGraph(tensor_input);
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}
std::vector<int64_t> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) {
return {0, 2, 3, 1};
@ -66,8 +83,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
auto input_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto output_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({input_format});
builder.SetInputsDeviceType({input_type});
builder.SetInputsFormat({input_format, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({input_type, kNumberTypeInt64});
builder.SetOutputsFormat({output_format});
builder.SetOutputsDeviceType({output_type});
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
@ -81,6 +98,8 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(used_node);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
<< ", index: " << used_node_index;
@ -97,18 +116,23 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2.Set the input of transpose.
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
if (is_fake) {
auto reshape_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index);
auto shape_node = NewValueNode(MakeValue<std::vector<int64_t>>(reshape_shape));
auto shape_tensor = ConvertValueToTensor(kernel_graph, shape_node);
transpose_input.push_back(shape_tensor);
} else {
auto perm_node = NewValueNode(MakeValue<std::vector<int64_t>>(transpose_perm));
auto perm_tensor = ConvertValueToTensor(kernel_graph, perm_node);
transpose_input.push_back(perm_tensor);
}
auto transpose_op = graph->NewCNode(transpose_input);
MS_EXCEPTION_IF_NULL(transpose_op);
// 3.Set the output info of transpose.
auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
auto transpose_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index);
auto base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index);
common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, {base_shape}, transpose_op.get());
if (is_fake) {
common::AnfAlgo::SetNodeAttr("shape", MakeValue(transpose_shape), transpose_op);
} else {
common::AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
}
// 4. Set the new edge of transpose op.
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);

View File

@ -22,6 +22,7 @@
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/convert_utils.h"
#include "ir/primitive.h"
#include "include/common/utils/utils.h"
#include "plugin/device/gpu/hal/device/kernel_info_setter.h"
@ -124,6 +125,23 @@ void CalSplitAttrs(SplitvNodeInfo *splitv_node_info) {
splitv_node_info->num_split = num_split;
}
// Convert a ValueNode holding an int64 tuple (e.g. slice begin/size values) into a
// Tensor-backed ValueNode registered in the kernel graph, so the value can feed an
// operator as a normal tensor input rather than an attribute.
//
// \param kernel_graph Graph that will own the new value node; must not be null.
// \param input_node   ValueNode whose value is a ValueTuple of integers; must not be null.
// \return The tensor-backed ValueNode, already added to the graph's value-node set.
AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(input_node);
  // input_node is already a ValueNodePtr, so read its value directly (no re-cast needed).
  auto value = input_node->value();
  MS_EXCEPTION_IF_NULL(value);
  // Materialize the int tuple as a 1-D tensor; a null result means the value was not a
  // ValueTuple (or was otherwise unconvertible).
  tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  // Re-create the node through the kernel graph so it carries graph-managed kernel info,
  // then register it so it stays alive with the graph.
  tensor_input = kernel_graph->NewValueNode(tensor_input);
  kernel_graph->AddValueNodeToGraph(tensor_input);
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}
CNodePtr CreateSliceNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &slice_input,
const SliceNodeInfo slice_node_info, const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
@ -131,20 +149,13 @@ CNodePtr CreateSliceNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr
MS_LOG(EXCEPTION) << "The input is empty, can not create slice node.";
return nullptr;
}
auto slice = pass.NewCNode(slice_input, graph);
MS_EXCEPTION_IF_NULL(slice);
ShapeVector slice_shape(slice_node_info.base_shape);
std::vector<int64_t> begins(slice_shape.size(), 0);
slice_shape[slice_node_info.slice_dim] = slice_node_info.slice_size;
begins[slice_node_info.slice_dim] = slice_node_info.slice_begin;
std::vector<ShapeVector> shapes = {slice_shape};
std::vector<TypeId> dtypes(1, slice_node_info.input_dtype);
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, slice.get());
common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(begins), slice);
common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(slice_shape), slice);
return slice;
}
@ -152,6 +163,8 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const AnfNodePtr &split_inpu
const PatternProcessPass &pass) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(splitv_node_info);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
CalSplitAttrs(splitv_node_info);
@ -168,7 +181,18 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const AnfNodePtr &split_inpu
for (int64_t idx = 0; idx < splitv_node_info->num_split; ++idx) {
slice_node_info.slice_size = splitv_node_info->size_splits[idx];
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), split_input};
ShapeVector slice_shape(slice_node_info.base_shape);
slice_shape[slice_node_info.slice_dim] = slice_node_info.slice_size;
// begins
std::vector<int64_t> begins(slice_shape.size(), 0);
begins[slice_node_info.slice_dim] = slice_node_info.slice_begin;
auto beigns_node = NewValueNode(MakeValue<std::vector<int64_t>>(begins));
auto beigns_tensor = ConvertValueToTensor(kernel_graph, beigns_node);
// slice_shape
auto slice_shape_node = NewValueNode(MakeValue<std::vector<int64_t>>(slice_shape));
auto slice_shape_tensor = ConvertValueToTensor(kernel_graph, slice_shape_node);
std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), split_input,
beigns_tensor, slice_shape_tensor};
slice = CreateSliceNode(graph, slice_inputs, slice_node_info, pass);
MS_EXCEPTION_IF_NULL(slice);
AddNewOutputs(graph, slice, 1, &make_tuple_inputs);

View File

@ -24,16 +24,6 @@ namespace mindspore::opt {
RER_GPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReshapeOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kConv2DBackpropFilterOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);
@ -49,24 +39,11 @@ RER_GPU_STATIC_CONST_TO_ATTR(kCSRReduceSumOpName, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kCumprodOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_STATIC_CONST_TO_ATTR(kMeanGradOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kOneHotOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kParallelResizeBilinearGradOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReshapeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
} // namespace mindspore::opt
#endif // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_

View File

@ -26,9 +26,13 @@ namespace mindspore {
namespace opt {
const BaseRef RemoveFormatTransformPair::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
VarPtr Z = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X});
VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1});
MS_EXCEPTION_IF_NULL(Y);
MS_EXCEPTION_IF_NULL(Z);
VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X, Y});
VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1, Z});
return transpose2;
}

View File

@ -27,8 +27,10 @@ namespace mindspore {
namespace opt {
const BaseRef RemoveRedundantFormatTransform::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Y = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
VectorRef transpose = VectorRef({prim::kPrimTranspose, X});
MS_EXCEPTION_IF_NULL(Y);
VectorRef transpose = VectorRef({prim::kPrimTranspose, X, Y});
return transpose;
}
@ -49,8 +51,9 @@ const AnfNodePtr RemoveRedundantFormatTransform::Process(const FuncGraphPtr &gra
break;
}
}
auto first_transpose_perm = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(first_transpose, "perm");
auto node_perm = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
const int64_t perm_param_idx = 1;
auto first_transpose_perm = AnfAlgo::GetInputDeviceShape(first_transpose, perm_param_idx);
auto node_perm = AnfAlgo::GetInputDeviceShape(node, perm_param_idx);
if ((first_transpose != node) && (first_transpose_perm == node_perm)) {
return first_transpose;
}

View File

@ -50,7 +50,7 @@ OperatorRegister &OperatorRegister::GetInstance() {
return instance;
}
std::map<std::string, OperatorDefineFunc> OperatorRegister::GetOperatorMap() { return operator_fns_; }
const std::map<std::string, OperatorDefineFunc> &OperatorRegister::GetOperatorMap() { return operator_fns_; }
void OperatorRegister::SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn) {
operator_fns_[kname] = fn;
}

View File

@ -63,7 +63,7 @@ class MIND_API OperatorRegister {
static OperatorRegister &GetInstance();
std::map<std::string, OperatorDefineFunc> GetOperatorMap();
const std::map<std::string, OperatorDefineFunc> &GetOperatorMap();
void SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn);

View File

@ -116,6 +116,7 @@ constexpr auto kIoFormat = "io_format";
constexpr auto kIsScale = "is_scale";
constexpr auto kIsTraining = "is_training";
constexpr auto kKeepDims = "keep_dims";
constexpr auto kSkipMode = "skip_mode";
constexpr auto kKeepProb = "keep_prob";
constexpr auto kPadDimSize = "pad_dim_size";
constexpr auto kUnbiased = "unbiased";

View File

@ -174,11 +174,17 @@ bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_ar
input_args[kInputIndex1]->isa<abstract::AbstractTuple>() ||
input_args[kInputIndex1]->isa<abstract::AbstractList>()) {
*axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
if (axis_value->empty()) {
*axis_shape_v = 0;
}
} else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
(void)CheckAndConvertUtils::CheckTensorTypeValid("axis", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64},
op_name);
if (input_value->isa<tensor::Tensor>()) {
*axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, op_name);
if (axis_value->empty()) {
*axis_shape_v = 0;
}
} else {
is_dynamic = true;
auto axis_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
@ -199,6 +205,14 @@ bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_ar
return is_dynamic;
}
// Decide whether the reduce op should skip execution entirely: when skip_mode is
// enabled and the axis input is an empty tensor (any dimension of its shape is 0),
// ReduceSum is a no-op and the input passes through unchanged.
//
// \param skip_mode  Value of the op's skip_mode attribute.
// \param axes_shape Shape of the axis input tensor.
// \return true iff execution should be skipped.
bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
  // Test the cheap flag first so the shape scan is short-circuited when skip_mode is off.
  if (!skip_mode) {
    return false;
  }
  return std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t dim) { return dim == 0; });
}
abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args,
const std::string &prim_name) {
@ -206,6 +220,12 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
MS_EXCEPTION_IF_NULL(shape_ptr);
auto x_shape = shape_ptr->shape();
bool skip_mode = false;
if (primitive->HasAttr(kSkipMode)) {
auto skip_mode_value_ptr = primitive->GetAttr(kSkipMode);
MS_EXCEPTION_IF_NULL(skip_mode_value_ptr);
skip_mode = GetValue<bool>(skip_mode_value_ptr);
}
auto keep_dimis_value_ptr = primitive->GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(keep_dimis_value_ptr);
if (!keep_dimis_value_ptr->isa<BoolImm>()) {
@ -213,8 +233,11 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
}
bool keep_dims = GetValue<bool>(keep_dimis_value_ptr);
std::vector<int64_t> axis_value;
int64_t axis_shape = 0;
int64_t axis_shape = 1;
bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
if (IsDynamicShapeSkipExecute(skip_mode, {axis_shape})) {
return std::make_shared<abstract::Shape>(x_shape);
}
ShapeVector out_shape = {};
constexpr int dynamic_rank_value = -2;
if (IsDynamicRank(x_shape)) {

View File

@ -28,23 +28,13 @@ void Reduce::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims
bool Reduce::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }
void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); }
void Reduce::set_skip_mode(const bool skip_mode) { (void)this->AddAttr(kSkipMode, api::MakeValue(skip_mode)); }
void Reduce::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
bool Reduce::get_skip_mode() const { return GetValue<bool>(GetAttr(kSkipMode)); }
std::vector<int64_t> Reduce::get_axis() {
PrimitivePtr prim = this->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
std::vector<int64_t> axis = {};
if (prim->HasAttr(kAttrAxis)) {
auto value_ptr = prim->GetAttr(kAttrAxis);
if (value_ptr->isa<tensor::Tensor>()) {
axis = CheckAndConvertUtils::CheckTensorIntValue(kAttrAxis, value_ptr, prim->name());
} else {
axis = CheckAndConvertUtils::CheckIntOrTupleInt(kAttrAxis, value_ptr, prim->name());
}
}
return axis;
void Reduce::Init(const bool keep_dims, const bool skip_mode) {
this->set_keep_dims(keep_dims);
this->set_skip_mode(skip_mode);
}
MIND_API_OPERATOR_IMPL(Reduce, BaseOperator);

View File

@ -39,7 +39,8 @@ class MIND_API Reduce : public BaseOperator {
/// \brief Method to init the op's attributes.
///
/// \param[in] keep_dims Define whether keep the dims reduced, default false.
void Init(const bool keep_dims = false);
/// \param[in] skip_mode Define whether skip reduce, default false.
void Init(const bool keep_dims = false, const bool skip_mode = false);
/// \brief Method to set keep_dims attribute.
///
@ -51,9 +52,15 @@ class MIND_API Reduce : public BaseOperator {
/// \return keep_dims attribute.
bool get_keep_dims() const;
void set_axis(const std::vector<int64_t> &axis);
/// \brief Method to set skip_mode attribute.
///
/// \param[in] skip_mode Define whether skip reduce, default false.
void set_skip_mode(const bool skip_mode);
std::vector<int64_t> get_axis();
/// \brief Method to get skip_mode attribute.
///
/// \return skip_mode attribute.
bool get_skip_mode() const;
};
abstract::AbstractBasePtr ReduceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);

View File

@ -28,6 +28,7 @@
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kSliceInputNum = 3;
constexpr int64_t kDynamicOutValue = -2;
std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive, const ValuePtr &input_value) {
std::vector<int64_t> tmp_input;
@ -87,6 +88,7 @@ ShapeVector GetOutputShape(const ShapeVector &input_size_shape, const ShapeVecto
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kSliceInputNum, "Slice inputs num error");
auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto input_begin_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -447,6 +447,7 @@ int64_t StridedSlice::get_shrink_axis_mask() const {
auto value_ptr = GetAttr(kShrinkAxisMask);
return GetValue<int64_t>(value_ptr);
}
std::vector<int64_t> StridedSlice::get_begin() const {
auto value_ptr = GetAttr(kAttrBegin);
return GetValue<std::vector<int64_t>>(value_ptr);
@ -459,6 +460,7 @@ std::vector<int64_t> StridedSlice::get_strides() const {
auto value_ptr = GetAttr(kAttrStrides);
return GetValue<std::vector<int64_t>>(value_ptr);
}
void StridedSlice::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);

View File

@ -29,8 +29,6 @@
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
std::vector<int64_t> Transpose::get_perm() {
PrimitivePtr prim = this->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
@ -46,46 +44,31 @@ std::vector<int64_t> Transpose::get_perm() {
return perm;
}
bool CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, ShapeVector *perm_value,
const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(perm_value);
bool is_dynamic = false;
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
ShapeVector CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, const PrimitivePtr &primitive) {
const std::string &op_name = primitive->name();
if (input_args.size() == 1) {
if (!primitive->HasAttr("perm")) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of 'input_perm' is necessary, but missing it.";
}
ValuePtr perm = primitive->GetAttr("perm");
MS_EXCEPTION_IF_NULL(perm);
if (perm->isa<tensor::Tensor>()) {
*perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", perm, op_name);
} else {
*perm_value = CheckAndConvertUtils::CheckTupleInt("perm", perm, op_name);
}
return is_dynamic;
}
const size_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto input_value = input_args[kInputIndex1]->BuildValue();
if (input_args[kInputIndex1]->isa<abstract::AbstractTuple>()) {
*perm_value = CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
return CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
} else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
if (input_value->isa<tensor::Tensor>()) {
*perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
} else {
is_dynamic = true;
return CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
}
auto perm_shape = CheckAndConvertUtils::GetTensorInputShape("perm", input_args, 1);
if (perm_shape->shape().size() != 1) {
MS_EXCEPTION(ValueError) << "For 'transpose perm', " << op_name << " must be 1-D, but got"
<< perm_shape->shape().size() << "-D.";
}
}
} else {
MS_EXCEPTION(TypeError) << "For primitive[" << op_name
<< "], the perm must be a tuple or a tensor with all Int elements, but got "
<< input_args[kInputIndex1]->type_name() << ".";
}
return is_dynamic;
return {};
}
class TransposeInfer : public abstract::OpInferBase {
@ -97,7 +80,6 @@ class TransposeInfer : public abstract::OpInferBase {
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
ShapeVector p_value;
ShapeVector p_value_raw;
if (x_shape[0] == 0) {
MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
}
@ -105,10 +87,10 @@ class TransposeInfer : public abstract::OpInferBase {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
}
bool perm_is_dynamic = CheckAndGetPermValue(input_args, &p_value_raw, primitive);
if (perm_is_dynamic) {
auto p_value_raw = CheckAndGetPermValue(input_args, primitive);
if (p_value_raw.empty()) {
ShapeVector out_shape;
(void)out_shape.insert(out_shape.end(), x_shape.size(), -1);
(void)out_shape.insert(out_shape.end(), x_shape.size(), abstract::Shape::kShapeDimAny);
return std::make_shared<abstract::Shape>(out_shape);
}

View File

@ -223,16 +223,15 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
auto expander = std::make_shared<DefaultExpander>(Callback::Instance());
ExpanderCreatorFuncList decos = {InferValueDeco::Creator};
std::map<std::string, ExpanderCreatorFuncList> creators = {
{prim::kPrimReduceFusion->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
{prim::kPrimExpandDims->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
{prim::kPrimReduceFusion->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimExpandDims->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimUnsqueeze->name(), {FixFormatDeco::Creator}},
{prim::kPrimSqueeze->name(), {FixFormatDeco::Creator}},
{prim::kPrimShape->name(), {FixFormatDeco::Creator}},
{prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
{prim::kPrimConstantOfShape->name(), {InputToAttrDeco::GetCreator({0}), FixFormatDeco::Creator}},
{prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
{prim::kPrimGather->name(),
{TensorToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}), FixFormatDeco::Creator}},
{prim::kPrimReshape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimConstantOfShape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::Creator}},
{prim::kPrimGather->name(), {TensorToValueDeco::GetCreator({2}), InputToAttrDeco::Creator, FixFormatDeco::Creator}},
{prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
{prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
{prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},

View File

@ -149,5 +149,5 @@ def dyn_ones(shape, value_type):
def sum_grad_reduce_axis(x, rx, keep_dims=False):
"""sum the grad_reduce_axis"""
return P.ReduceSum(keep_dims=keep_dims)(x, rx)
"""sum the grad_reduce_axis. when rx is an empty tensor and skip_mode is True, we need skip this reduce"""
return P.ReduceSum(keep_dims=keep_dims, skip_mode=True)(x, rx)

View File

@ -513,7 +513,6 @@ class _Reduce(PrimitiveWithCheck):
value = None
if input_x is not None and axis is not None:
prim_map = {
'ReduceSum': np.sum,
'ReduceMax': np.max,
'ReduceMin': np.min,
'ReduceProd': np.prod,
@ -737,7 +736,7 @@ class CumulativeLogsumexp(Primitive):
validator.check_bool(reverse, "reverse", self.name)
class ReduceSum(_Reduce):
class ReduceSum(PrimitiveWithCheck):
"""
Reduces a dimension of a tensor by summing all elements in the dimension, by default. And also can reduce a
dimension of `x` along the axis. Determine whether the dimensions of the output and input are the same by
@ -746,18 +745,24 @@ class ReduceSum(_Reduce):
Args:
keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions. Default: False.
skip_mode (bool): If true and axis is empty tuple or empty list, the ReduceSum operation isn't performed,
skip it.
If true and axis is other values, the ReduceSum calculation is performed normally.
If false, do reduce. Default: False.
Inputs:
- **x** (Tensor[Number]) - The input tensor. The dtype of the tensor to be reduced is number.
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
Only constant value is allowed. Must be in the range [-rank(`x`), rank(`x`)).
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions
when skip_mode is false. Only constant value is allowed. Must be in the range [-rank(`x`), rank(`x`)).
Outputs:
Tensor, has the same dtype as the `x`.
- If axis is (), and keep_dims is False,
- If axis is (), keep_dims is False, and skip_mode is False,
the output is a 0-D tensor representing the sum of all elements in the input tensor.
- If axis is (), and skip_mode is True,
the ReduceSum operation is not performed, output tensor is equal to the input tensor.
- If axis is int, set as 2, and keep_dims is False,
the shape of output is :math:`(x_1, x_3, ..., x_R)`.
- If axis is tuple(int) or list(int), set as (2, 3), and keep_dims is False,
@ -765,6 +770,7 @@ class ReduceSum(_Reduce):
Raises:
TypeError: If `keep_dims` is not a bool.
TypeError: If `skip_mode` is not a bool.
TypeError: If `x` is not a Tensor.
ValueError: If `axis` is None.
@ -812,12 +818,44 @@ class ReduceSum(_Reduce):
[54.]]]
"""
__mindspore_signature__ = (
sig.make_sig('input_x'),
sig.make_sig('axis', default=())
)
@prim_attr_register
def __init__(self, keep_dims=False):
"""Initialize ReduceSum"""
super(ReduceSum, self).__init__(keep_dims)
def __init__(self, keep_dims=False, skip_mode=False):
"""Initialize Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
validator.check_value_type('skip_mode', skip_mode, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
self.keep_dims = keep_dims
self.skip_mode = skip_mode
self.__setattr_flag__ = True
def __call__(self, x, axis=()):
args = [x, axis]
output = _run_op(self, self.name, args)
return output
def infer_value(self, input_x, axis):
""" return reduce op value"""
value = None
if input_x is not None and axis is not None:
value = input_x.asnumpy()
if isinstance(axis, int):
pass
elif axis:
axis = tuple(set(axis))
elif axis in ((), []) and self.skip_mode:
return input_x
else:
axis = tuple(range(len(value.shape)))
value = np.sum(value, axis, keepdims=self.keep_dims)
value = np.array(value)
value = Tensor(value)
return value
class ReduceAll(_Reduce):
"""