!45367 Graph Kernel: adapt input/attr conversion for graph kernel
Merge pull request !45367 from ZengZitao/convert_input_attr_gk
commit 01e44b24a8
@@ -1,7 +1,7 @@
mindspore.ops.ReduceSum
=========================

.. py:class:: mindspore.ops.ReduceSum(keep_dims=False)
.. py:class:: mindspore.ops.ReduceSum(keep_dims=False, skip_mode=False)

By default, this operation sums the elements of the input Tensor over every dimension, reducing all of them. It can also sum over the specified dimensions only.

@@ -9,15 +9,18 @@ mindspore.ops.ReduceSum

Args:
    - **keep_dims** (bool) - If True, the reduced dimensions are kept with length 1; if False, they are removed and the output rank is lowered. Default: False.
    - **skip_mode** (bool) - If True and axis is an empty tuple or empty list, the ReduceSum computation is skipped; for any other axis the operation runs normally. If False, the operation always runs normally. Default: False.

Inputs:
    - **x** (Tensor[Number]) - The input of ReduceSum, with data type Number. Shape: :math:`(N, *)`, where :math:`*` means any number of additional dimensions. Its rank should be less than 8.
    - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions. Only constant values are allowed, in the range [-rank(`x`), rank(`x`)).
    - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions when skip_mode is False. Only constant values are allowed, in the range [-rank(`x`), rank(`x`)).

Outputs:
    Tensor, with the same dtype as the input `x`.

    - If axis is () and keep_dims is False, the output is a 0-dimensional Tensor containing the sum of all elements in the input Tensor.
    - If axis is (), keep_dims is False and skip_mode is False, the output is a 0-dimensional Tensor containing the sum of all elements in the input Tensor.

    - If axis is () and skip_mode is True, the ReduceSum computation is skipped and the output Tensor equals the input Tensor.

    - If axis is an int, e.g. 2, and keep_dims is False, the output shape is :math:`(x_1, x_3, ..., x_R)`.

@@ -25,5 +28,6 @@ mindspore.ops.ReduceSum

Raises:
    - **TypeError** - `keep_dims` is not a bool.
    - **TypeError** - `skip_mode` is not a bool.
    - **TypeError** - `x` is not a Tensor.
    - **ValueError** - `axis` is None.
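The interaction of `keep_dims` and the new `skip_mode` flag boils down to a single predicate. The standalone sketch below restates the documented rule; `ReduceSumIsIdentity` is a hypothetical helper for illustration, not MindSpore API:

```cpp
// Sketch of the documented ReduceSum(keep_dims, skip_mode) skip rule.
// ReduceSumIsIdentity is an invented name, not part of MindSpore.
#include <cstdint>
#include <iostream>
#include <vector>

bool ReduceSumIsIdentity(bool skip_mode, const std::vector<int64_t> &axis) {
  // The op degenerates to an identity only when skip_mode is true and the
  // axis list is empty; an empty axis with skip_mode == false still means
  // "reduce every dimension".
  return skip_mode && axis.empty();
}

int main() {
  std::cout << ReduceSumIsIdentity(true, {}) << '\n';   // 1: output == input
  std::cout << ReduceSumIsIdentity(false, {}) << '\n';  // 0: reduce all dims
  std::cout << ReduceSumIsIdentity(true, {2}) << '\n';  // 0: reduce dim 2
  return 0;
}
```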
@@ -307,12 +307,7 @@ void InferOp(const CNodePtr &cnode, void *args) {
  MS_EXCEPTION_IF_NULL(kernel_mod);

  kernel::KernelArgs kernel_args;
  if (AnfAlgo::IsDynamicShapeSkipExecute(cnode)) {
    std::vector<TypeId> dtypes{common::AnfAlgo::GetOutputInferDataType(cnode, 0)};
    common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetInputDeviceShape(cnode, 0)}, cnode.get());
  } else {
    InferShape(cnode, &kernel_args.depend_tensor_map, args);
  }

  if (auto kernel_mod_type = kernel_mod->GetKernelModType(); IsCpuGpuKernelMod(kernel_mod_type)) {
    auto update = kernel::AbstractArgsFromCNode(cnode, IsDeprecatedCpuOrGpuKernelMod(kernel_mod_type));
@@ -257,12 +257,21 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt
  return tensor;
}

tensor::TensorPtr CreateEmptyTupleTensor() {
  std::vector<int64_t> tensor_shape = {0};
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), tensor_shape);
  MS_EXCEPTION_IF_NULL(tensor);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
  tensor->set_device_info(device_info);
  return tensor;
}

tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor = nullptr;
  if (value_tuple->value().empty()) {
    MS_LOG(WARNING) << "The value tuple is empty.";
    return nullptr;
    tensor = CreateEmptyTupleTensor();
    return tensor;
  }
  ValuePtr v = *(value_tuple->value().begin());
  MS_EXCEPTION_IF_NULL(v);
@@ -1355,12 +1355,9 @@ void AnfRuntimeAlgorithm::UpdateGraphValidRefPair(const KernelGraphPtr &graph) {
  graph->set_ref_out_in_map(new_ref_map);
}

bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const std::string &op_name, const ShapeVector &axes_shape) {
bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
  // Skip running ReduceSum when the axis input is an empty Tensor.
  if (op_name != kReduceSumOpName) {
    return false;
  }
  if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
  if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
    return true;
  }
  return false;

@@ -1374,6 +1371,11 @@ bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
    return false;
  }

  bool skip_mode = false;
  if (common::AnfAlgo::HasNodeAttr(kAttrSkipMode, cnode)) {
    skip_mode = common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrSkipMode);
  }

  const size_t axes_index = 1;
  if (cnode->inputs().size() <= axes_index + 1) {
    return false;

@@ -1387,7 +1389,7 @@ bool AnfRuntimeAlgorithm::IsDynamicShapeSkipExecute(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(axes_abs);
  auto axes_shape = AnfAlgo::GetInputDeviceShape(cnode, axes_index);
  if (axes_abs->isa<abstract::AbstractTensor>()) {
    if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; })) {
    if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
      return true;
    }
  }
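The runtime-side check above combines two signals: the per-node `skip_mode` attribute and whether the axes input is an empty tensor (any dimension of its shape being 0 means it holds zero elements). A minimal standalone restatement, assuming plain STL types instead of MindSpore's ShapeVector:

```cpp
// Standalone sketch of the skip-execute decision, mirroring
// IsDynamicShapeSkipExecute(skip_mode, axes_shape) above.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

bool IsDynamicShapeSkipExecute(bool skip_mode, const std::vector<int64_t> &axes_shape) {
  // A tensor whose shape contains a 0 has zero elements, so an axes input
  // shaped {0} is an empty axis list; skip only if the op opted in.
  bool axes_empty = std::any_of(axes_shape.begin(), axes_shape.end(),
                                [](int64_t dim) { return dim == 0; });
  return axes_empty && skip_mode;
}

int main() {
  std::cout << IsDynamicShapeSkipExecute(true, {0}) << '\n';   // 1: skip launch
  std::cout << IsDynamicShapeSkipExecute(false, {0}) << '\n';  // 0: run kernel
  std::cout << IsDynamicShapeSkipExecute(true, {3}) << '\n';   // 0: 3 axes given
  return 0;
}
```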
@@ -172,7 +172,7 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
  static void CacheAddrForAtomicClean(const AnfNodePtr &node, kernel::KernelMod *kernel_mod);

  static void UpdateGraphValidRefPair(const KernelGraphPtr &graph);
  static bool IsDynamicShapeSkipExecute(const std::string &op_name, const ShapeVector &axes_shape);
  static bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape);
  static bool IsDynamicShapeSkipExecute(const CNodePtr &cnode);
  // Returns true if the output's shape and type need to be updated after launch.
  static bool IsNeedUpdateShapeAndTypeAfterLaunch(const AnfNodePtr &cnode);
@@ -30,6 +30,7 @@
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/adapter/callback_impl.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "backend/common/pass/inplace_assign_for_custom_op.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"

@@ -65,7 +66,7 @@ ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
    {prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimVmapUnstackAssign->name(), {AttrToInputDeco::GetCreator(true)}},
    {prim::kPrimVmapUnstackAssign->name(), {AttrToInputDeco::Creator}},
  };
  const auto iter = creators.find(GetCNodePrimitive(node)->name());
  if (iter != creators.end()) {

@@ -178,7 +179,7 @@ AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const
    return nullptr;
  }
  auto expander = GetExpander(node);
  expander = AttrToInputDeco::GetCreator(true)(expander);
  expander = AttrToInputDeco::Creator(expander);
  auto res = expander->Run(node);
  auto expand_fg = GetCNodeFuncGraph(res);
  if (expand_fg != nullptr) {
@@ -240,49 +241,6 @@ AnfNodePtr SetDynamicShapeAttrDeco::Run(const AnfNodePtr &node) {
  return new_cnode;
}

void AttrToInputDeco::ConvertAttrToInput(const FuncGraphPtr &graph) {
  auto todos = TopoSort(graph->get_return());
  for (const auto &node : todos) {
    auto primitive = GetCNodePrimitive(node);
    if (primitive == nullptr) {
      continue;
    }
    std::map<std::string, std::set<size_t>> attr2input_map;
    if (is_backend_) {
      attr2input_map = std::map<std::string, std::set<size_t>>{{prim::kPrimTupleGetItem->name(), {1}}};
    } else {
      attr2input_map = std::map<std::string, std::set<size_t>>{
        {prim::kPrimCast->name(), {1}}, {prim::kPrimReshape->name(), {1}}, {prim::kPrimReduceMax->name(), {1}},
        {prim::kPrimReduceMin->name(), {1}}, {prim::kPrimReduceSum->name(), {1}}, {prim::kPrimTranspose->name(), {1}},
        {prim::kPrimTupleGetItem->name(), {1}}};
    }
    if (attr2input_map.count(primitive->name()) != 0) {
      auto input_names = primitive->GetAttr(kAttrInputNames);
      auto cnode = dyn_cast<CNode>(node);
      AnfNodePtrList inputs = cnode->inputs();
      AnfNodePtrList new_inputs{inputs[0]};
      auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
      auto attrs_map = attr2input_map[primitive->name()];
      size_t j = 1;
      for (size_t i = 0; i < input_names_vec.size(); ++i) {
        if (attrs_map.count(i) != 0) {
          auto value = primitive->GetAttr(input_names_vec[i]);
          auto value_node = std::make_shared<ValueNode>(value);
          value_node->set_abstract(value->ToAbstract());
          new_inputs.push_back(value_node);
        } else {
          if (j >= inputs.size()) {
            MS_LOG(EXCEPTION) << "Index " << j << " is larger than the input size [" << inputs.size() << "]";
          }
          new_inputs.push_back(inputs[j]);
          j++;
        }
      }
      cnode->set_inputs(new_inputs);
    }
  }
}

AnfNodePtr AttrToInputDeco::Run(const AnfNodePtr &node) {
  auto new_node = decorated_->Run(node);
  if (new_node == nullptr) {

@@ -290,7 +248,10 @@ AnfNodePtr AttrToInputDeco::Run(const AnfNodePtr &node) {
  }
  auto new_cnode = dyn_cast<CNode>(new_node);
  auto expand_fg = GetCNodeFuncGraph(new_cnode);
  ConvertAttrToInput(expand_fg);
  auto todos = TopoSort(expand_fg->get_return());
  for (const auto &node : todos) {
    ConvertOpUtils::ConvertAttrToInput(node);
  }
  new_cnode->set_input(0, NewValueNode(expand_fg));
  return new_cnode;
}
@@ -75,19 +75,12 @@ class SetDynamicShapeAttrDeco : public ExpanderDecorator {

class COMMON_EXPORT AttrToInputDeco : public ExpanderDecorator {
 public:
  AttrToInputDeco(const ExpanderPtr &decorated, bool is_backend)
      : ExpanderDecorator(decorated), is_backend_(is_backend) {}
  explicit AttrToInputDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
  ~AttrToInputDeco() override = default;
  static ExpanderCreatorFunc GetCreator(bool is_backend) {
    return [is_backend](const ExpanderPtr &decorated) {
      return std::static_pointer_cast<Expander>(std::make_shared<AttrToInputDeco>(decorated, is_backend));
    };
  static ExpanderPtr Creator(const ExpanderPtr &decorated) {
    return std::static_pointer_cast<Expander>(std::make_shared<AttrToInputDeco>(decorated));
  }
  AnfNodePtr Run(const AnfNodePtr &node) override;

 protected:
  void ConvertAttrToInput(const FuncGraphPtr &graph);
  bool is_backend_{false};
};

/**
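The `GetCreator`/`Creator` change above simplifies the expander decorator pattern this file is built on: each decorator wraps another `Expander`, and a creator function builds the wrapper. A minimal sketch of the pattern with invented names (illustrative stand-ins, not the actual MindSpore classes):

```cpp
// Minimal decorator-chain sketch; Expander/BaseExpander/AttrToInputDeco here
// are illustrative stand-ins for the MindSpore graphkernel classes.
#include <functional>
#include <iostream>
#include <memory>
#include <string>

struct Expander {
  virtual ~Expander() = default;
  virtual std::string Run(const std::string &node) = 0;
};
using ExpanderPtr = std::shared_ptr<Expander>;
using ExpanderCreatorFunc = std::function<ExpanderPtr(const ExpanderPtr &)>;

struct BaseExpander : Expander {
  std::string Run(const std::string &node) override { return "expand(" + node + ")"; }
};

struct AttrToInputDeco : Expander {
  explicit AttrToInputDeco(ExpanderPtr decorated) : decorated_(std::move(decorated)) {}
  // Pre-process the node, then delegate to the wrapped expander.
  std::string Run(const std::string &node) override {
    return decorated_->Run("attr2input(" + node + ")");
  }
  static ExpanderPtr Creator(const ExpanderPtr &decorated) {
    return std::make_shared<AttrToInputDeco>(decorated);
  }
  ExpanderPtr decorated_;
};

int main() {
  ExpanderPtr expander = std::make_shared<BaseExpander>();
  expander = AttrToInputDeco::Creator(expander);  // same shape as the diff above
  std::cout << expander->Run("ReduceSum") << '\n';  // expand(attr2input(ReduceSum))
  return 0;
}
```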
@@ -0,0 +1,142 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindspore/ccsrc/common/graph_kernel/core/convert_op_input_attr.h"
#include "mindspore/ccsrc/common/graph_kernel/core/graph_kernel_callback.h"

#include "mindspore/core/ops/core_ops.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/check_convert_utils.h"

namespace mindspore::graphkernel {
std::map<std::string, HashSet<size_t>> ConvertOpUtils::op_idx_info_ = {{prim::kPrimReshape->name(), {1}},
                                                                       {prim::kPrimReduceMax->name(), {1}},
                                                                       {prim::kPrimExpandDims->name(), {1}},
                                                                       {prim::kPrimReduceMin->name(), {1}},
                                                                       {prim::kPrimReduceSum->name(), {1}},
                                                                       {prim::kPrimTranspose->name(), {1}},
                                                                       {prim::kPrimTile->name(), {1}},
                                                                       {prim::kPrimReduceMean->name(), {1}},
                                                                       {prim::kPrimSlice->name(), {1, 2}},
                                                                       {prim::kPrimStridedSlice->name(), {1, 2, 3}},
                                                                       {prim::kPrimOneHot->name(), {1}},
                                                                       {prim::kPrimReduceFusion->name(), {1}},
                                                                       {prim::kPrimConstantOfShape->name(), {0}},
                                                                       {prim::kPrimGather->name(), {2}},
                                                                       {prim::kPrimTupleGetItem->name(), {1}},
                                                                       {prim::kPrimUnsortedSegmentSum->name(), {2}},
                                                                       {prim::kPrimCumSum->name(), {1}}};

bool ConvertOpUtils::ConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx) {
  AnfNodePtrList new_inputs;
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(INFO) << "input_names is nullptr in cnode[" + cnode->DebugString() + "]";
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  new_inputs.push_back(inputs[0]);
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_idx.count(i) != 0) {
      if (i >= input_names_vec.size()) {
        MS_LOG(INFO) << "Index " << i << " is larger than the input names size [" << input_names_vec.size() << "]";
        return false;
      }
      ValuePtr value = nullptr;
      if (input_node->isa<ValueNode>()) {
        auto value_node = input_node->cast<ValueNodePtr>();
        value = value_node->value();
      } else if (input_node->isa<Parameter>()) {
        auto parameter_node = input_node->cast<ParameterPtr>();
        value = parameter_node->abstract()->BuildValue();
      }
      if (value != nullptr) {
        auto value_vector = CheckAndConvertUtils::CheckTensorIntValue(input_names_vec[i], value, primitive->name());
        auto tensor = value->cast<tensor::TensorPtr>();
        auto tensor_shape = tensor->shape_c();
        if (tensor_shape.empty()) {
          primitive->set_attr(input_names_vec[i], MakeValue(value_vector[0]));
        } else {
          primitive->set_attr(input_names_vec[i], MakeValue(value_vector));
        }
      } else {
        MS_LOG(ERROR) << input_names_vec[i] << "'s value is null!";
      }
    } else {
      new_inputs.push_back(inputs[i + 1]);
    }
  }
  if (new_inputs.size() != inputs.size()) {
    cnode->set_inputs(new_inputs);
  }
  return true;
}

void ConvertOpUtils::ConvertAttrToInput(const AnfNodePtr &node) {
  if (Callback::Instance()->GetTargetFromContext() == kAscendDevice) {
    return;
  }
  auto cnode = dyn_cast<CNode>(node);
  auto primitive = GetCNodePrimitive(cnode);
  if (primitive == nullptr) {
    return;
  }
  auto attr2input_map = ConvertOpUtils::GetOpIndexInfo();
  if (attr2input_map.count(primitive->name()) != 0) {
    auto input_names = primitive->GetAttr(kAttrInputNames);
    AnfNodePtrList inputs = cnode->inputs();
    AnfNodePtrList new_inputs{inputs[0]};
    auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
    auto attrs_map = attr2input_map[primitive->name()];
    size_t j = 1;
    for (size_t i = 0; i < input_names_vec.size(); ++i) {
      if (attrs_map.count(i) != 0) {
        auto value = primitive->GetAttr(input_names_vec[i]);
        ValueNodePtr value_node;
        // Special case for TupleGetItem: its attr cannot be turned into a tensor input.
        if (primitive->name() == prim::kPrimTupleGetItem->name()) {
          value_node = std::make_shared<ValueNode>(value);
          value_node->set_abstract(value->ToAbstract());
        } else if (value->isa<Scalar>()) {
          auto value_scalar = GetValue<int64_t>(value);
          auto tensor_ptr = std::make_shared<tensor::Tensor>(value_scalar, kInt64);
          value_node = std::make_shared<ValueNode>(tensor_ptr);
          value_node->set_abstract(tensor_ptr->ToAbstract());
        } else if (value->isa<ValueTuple>()) {
          auto value_tuple = GetValue<std::vector<int64_t>>(value);
          auto tensor_ptr = std::make_shared<tensor::Tensor>(value_tuple, kInt64);
          value_node = std::make_shared<ValueNode>(tensor_ptr);
          value_node->set_abstract(tensor_ptr->ToAbstract());
        } else {
          MS_LOG(EXCEPTION) << "The value should be a scalar or a value tuple";
        }
        Callback::Instance()->SetEmptyKernelInfo(value_node);
        new_inputs.push_back(value_node);
      } else {
        if (j >= inputs.size()) {
          MS_LOG(EXCEPTION) << "Index " << j << " is larger than the input size [" << inputs.size() << "]";
        }
        new_inputs.push_back(inputs[j]);
        j++;
      }
    }
    cnode->set_inputs(new_inputs);
  }
}
}  // namespace mindspore::graphkernel
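The new ConvertOpUtils keeps one table, op name → indices of inputs that hold constant values, and uses it in both directions: `ConstInputToAttr` folds those inputs into primitive attrs before expansion, and `ConvertAttrToInput` materializes them back as value inputs afterwards. A self-contained sketch of that round trip over plain STL containers (all names here are illustrative, not MindSpore API):

```cpp
// Round-trip sketch of the const-input <-> attr conversion driven by an
// op->indices table, using plain STL types instead of AnfNode/Primitive.
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using Attrs = std::map<std::string, int>;

// Fold the inputs named in idx into attrs (ConstInputToAttr direction).
void InputToAttr(std::vector<int> *inputs, const std::vector<std::string> &names,
                 const std::set<size_t> &idx, Attrs *attrs) {
  std::vector<int> kept;
  for (size_t i = 0; i < inputs->size(); ++i) {
    if (idx.count(i) != 0) {
      (*attrs)[names[i]] = (*inputs)[i];
    } else {
      kept.push_back((*inputs)[i]);
    }
  }
  *inputs = kept;
}

// Materialize the attrs back as inputs, in declaration order
// (ConvertAttrToInput direction).
void AttrToInput(std::vector<int> *inputs, const std::vector<std::string> &names,
                 const std::set<size_t> &idx, Attrs *attrs) {
  std::vector<int> rebuilt;
  size_t j = 0;
  for (size_t i = 0; i < names.size(); ++i) {
    if (idx.count(i) != 0) {
      rebuilt.push_back((*attrs)[names[i]]);
      attrs->erase(names[i]);
    } else {
      rebuilt.push_back((*inputs)[j++]);
    }
  }
  *inputs = rebuilt;
}

int main() {
  // ReduceSum-like op: input 0 is data, input 1 is the constant axis.
  std::vector<std::string> names{"x", "axis"};
  std::set<size_t> idx{1};
  std::vector<int> inputs{42, 2};
  Attrs attrs;
  InputToAttr(&inputs, names, idx, &attrs);  // inputs: {42}, attrs: {axis: 2}
  AttrToInput(&inputs, names, idx, &attrs);  // inputs: {42, 2}, attrs: {}
  std::cout << inputs[0] << ' ' << inputs[1] << '\n';  // 42 2
  return 0;
}
```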
@@ -0,0 +1,39 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_

#include <string>
#include <tuple>
#include <vector>
#include <map>
#include <memory>
#include "mindspore/core/ops/core_ops.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/hash_map.h"

namespace mindspore::graphkernel {
class ConvertOpUtils {
 public:
  static bool ConstInputToAttr(const CNodePtr &cnode, const HashSet<size_t> &input_idx);
  static void ConvertAttrToInput(const AnfNodePtr &node);
  static std::map<std::string, HashSet<size_t>> &GetOpIndexInfo() { return op_idx_info_; }
  static std::map<std::string, HashSet<size_t>> op_idx_info_;
};
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_CONVERT_OP_INPUT_ATTR_H_
@@ -20,6 +20,7 @@
#include "utils/anf_utils.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/expanders/op_desc_registry.h"

namespace mindspore::graphkernel {

@@ -47,42 +48,15 @@ CNodePtr ExpanderDecorator::QuickCloneCNode(const AnfNodePtr &node, bool clone_p
  return new_node;
}

bool InputToAttrDeco::ConstInputToAttr(const CNodePtr &cnode) const {
  AnfNodePtrList new_inputs;
  auto primitive = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_names = primitive->GetAttr(kAttrInputNames);
  if (input_names == nullptr) {
    MS_LOG(INFO) << "input_names is nullptr in cnode[" + cnode->DebugString() + "]";
    return false;
  }
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
  auto inputs = cnode->inputs();
  new_inputs.push_back(inputs[0]);
  for (size_t i = 0; i < inputs.size() - 1; ++i) {
    auto input_node = inputs[i + 1];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_idx_.count(i) != 0 && input_node->isa<ValueNode>()) {
      auto value_node = input_node->cast<ValueNodePtr>();
      if (i >= input_names_vec.size()) {
        MS_LOG(INFO) << "Index " << i << " is larger than the input names size [" << input_names_vec.size() << "]";
        return false;
      }
      auto value = value_node->value();
      primitive->set_attr(input_names_vec[i], value);
    } else {
      new_inputs.push_back(inputs[i + 1]);
    }
  }
  if (new_inputs.size() != inputs.size()) {
    cnode->set_inputs(new_inputs);
  }
  return true;
}

AnfNodePtr InputToAttrDeco::Run(const AnfNodePtr &node) {
  auto cnode = QuickCloneCNode(node, true);
  auto ret = ConstInputToAttr(cnode);
  auto all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
  auto iter = all_op_index_info.find(GetCNodePrimitive(node)->name());
  if (iter == all_op_index_info.end()) {
    MS_LOG(DEBUG) << "This node doesn't need input-to-attr conversion";
    return nullptr;
  }
  auto ret = ConvertOpUtils::ConstInputToAttr(cnode, iter->second);
  return ret ? decorated_->Run(cnode) : nullptr;
}
@@ -65,20 +65,12 @@ using ExpanderCreatorFuncList = std::vector<ExpanderCreatorFunc>;

class COMMON_EXPORT InputToAttrDeco : public ExpanderDecorator {
 public:
  InputToAttrDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
      : ExpanderDecorator(decorated), input_idx_(input_idx) {}
  explicit InputToAttrDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
  ~InputToAttrDeco() = default;
  AnfNodePtr Run(const AnfNodePtr &node) override;

  static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
    return [input_idx](const ExpanderPtr &decorated) {
      return std::static_pointer_cast<Expander>(std::make_shared<InputToAttrDeco>(decorated, input_idx));
    };
  static ExpanderPtr Creator(const ExpanderPtr &decorated) {
    return std::static_pointer_cast<Expander>(std::make_shared<InputToAttrDeco>(decorated));
  }

 protected:
  bool ConstInputToAttr(const CNodePtr &cnode) const;
  HashSet<size_t> input_idx_;
  AnfNodePtr Run(const AnfNodePtr &node) override;
};

/**
@@ -32,5 +32,6 @@ AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const Fu
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool RemoveNonScalarConstTensorFromParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool EliminateMaketupleGetitem(const FuncGraphPtr &fg);
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs);
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_
@@ -24,12 +24,14 @@
#include "utils/hash_map.h"
#include "mindspore/core/ops/core_ops.h"
#include "ir/graph_utils.h"
#include "include/common/utils/anfalgo.h"
#include "utils/anf_utils.h"
#include "utils/ms_context.h"
#include "utils/file_utils.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {

@@ -430,12 +432,41 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
  }
  return changed;
}

void GraphKernelCluster::ConvertInputToAttrForClusterGraph(const AnfNodePtr &graph_kernel_node) {
  auto gk_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(graph_kernel_node);
  MS_EXCEPTION_IF_NULL(gk_graph);
  auto todos = TopoSort(gk_graph->get_return());
  for (const auto &n : todos) {
    if (n == nullptr) {
      continue;
    }
    auto node = n->cast<CNodePtr>();
    if (node != nullptr) {
      auto all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
      auto primitive = GetCNodePrimitive(node);
      if (primitive == nullptr) {
        continue;
      }
      auto iter = all_op_index_info.find(primitive->name());
      if (iter != all_op_index_info.end()) {
        (void)ConvertOpUtils::ConstInputToAttr(node, iter->second);
      }
    }
  }
  // Remove the inputs that became unused on the new node after the conversion.
  auto cnode = graph_kernel_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto inputs = cnode->inputs();
  AnfNodePtrList fn_inputs{inputs};
  EliminateRedundantParameters(gk_graph, &fn_inputs);
  cnode->set_inputs(fn_inputs);
}
void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
  AnfNodePtrList old_nodes;
  (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
                       [this](size_t id) { return this->nodes_[id]; });
  auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
  ConvertInputToAttrForClusterGraph(new_node);
  if (GraphKernelFlags::GetInstance().dump_as_text) {
    DumpClusterInfo(old_nodes, new_node);
  }
@@ -45,6 +45,7 @@ class GraphKernelCluster : public opt::Pass {
  std::vector<size_t> FindCandidates(size_t basenode_id);
  void RemoveWildGetitem(std::vector<size_t> *candidates);
  void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id);
  void ConvertInputToAttrForClusterGraph(const AnfNodePtr &graph_kernel_node);
  void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node);
  void DumpToFile();
  void Clean() {
@@ -23,6 +23,7 @@
#include "utils/anf_utils.h"
#include "common/graph_kernel/core/graph_builder.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"

namespace mindspore::graphkernel {
AnfNodePtr GraphKernelExpander::CreateExpandedNode(const CNodePtr &node, const std::string &name) const {

@@ -51,7 +52,11 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
        !AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
      continue;
    }

    auto all_op_index_info = ConvertOpUtils::GetOpIndexInfo();
    auto iter = all_op_index_info.find(GetCNodePrimitive(node)->name());
    if (iter != all_op_index_info.end()) {
      (void)ConvertOpUtils::ConstInputToAttr(node, iter->second);
    }
    MS_LOG(DEBUG) << "Expanding node: " << node->fullname_with_scope();
    auto newnode = InitExpander(node)->Run(node);
    if (newnode == nullptr) {
@@ -27,6 +27,7 @@
#include "kernel/akg/akg_kernel_json_generator.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/convert_op_input_attr.h"
#include "common/graph_kernel/split_model/split_model_factory.h"

namespace mindspore::graphkernel {

@@ -367,6 +368,7 @@ class Splitter {
 private:
  void ResetInlinedNodesKernelInfo() const {
    for (const auto &node : inlined_nodes_) {
      ConvertOpUtils::ConvertAttrToInput(node);
      Callback::Instance()->ResetKernelInfo(node);
    }
  }
@@ -988,11 +988,21 @@ std::vector<DShape> StandardNormalOp::InferShape(const NodePtrList &, const DAtt
}

void StridedSliceOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) {
  if (!primitive->HasAttr("new_axis_mask")) {
    (void)primitive->AddAttr("new_axis_mask", MakeValue((int64_t(0))));
  }
  if (!primitive->HasAttr("shrink_axis_mask")) {
    (void)primitive->AddAttr("shrink_axis_mask", MakeValue((int64_t(0))));
  }
  if (!primitive->HasAttr("end_mask")) {
    (void)primitive->AddAttr("end_mask", MakeValue((int64_t(0))));
  }
  if (!primitive->HasAttr("begin_mask")) {
    (void)primitive->AddAttr("begin_mask", MakeValue((int64_t(0))));
  }
  if (!primitive->HasAttr("ellipsis_mask")) {
    (void)primitive->AddAttr("ellipsis_mask", MakeValue((int64_t(0))));
  }
  const mindspore::HashSet<size_t> &convert_input_list{1, 2, 3};
  auto input_names = primitive->GetAttr(kAttrInputNames);
  auto input_names_vec = GetValue<std::vector<std::string>>(input_names);

@@ -1181,4 +1191,38 @@ void TupleGetItemOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrLi
    (void)abs_list->emplace_back(prim->GetAttr("index")->ToAbstract());
  }
}

void ReduceOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  if (prim->HasAttr("axis")) {
    (void)abs_list->emplace_back(prim->GetAttr("axis")->ToAbstract());
  }
}

void ReshapeOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  if (prim->HasAttr("shape")) {
    (void)abs_list->emplace_back(prim->GetAttr("shape")->ToAbstract());
  } else if (prim->HasAttr("input_shape")) {
    (void)abs_list->emplace_back(prim->GetAttr("input_shape")->ToAbstract());
  }
}

void TransposeOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  if (prim->HasAttr("input_perm")) {
    (void)abs_list->emplace_back(prim->GetAttr("input_perm")->ToAbstract());
  } else if (prim->HasAttr("perm")) {
    (void)abs_list->emplace_back(prim->GetAttr("perm")->ToAbstract());
  }
}

void TileOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  if (prim->HasAttr("multiples")) {
    (void)abs_list->emplace_back(prim->GetAttr("multiples")->ToAbstract());
  }
}

void OneHotOp::RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) {
  if (prim->HasAttr("depth")) {
    (void)abs_list->emplace_back(prim->GetAttr("depth")->ToAbstract());
  }
}
}  // namespace mindspore::graphkernel::inner
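Each RectifyAbstract override above follows the same shape: if the primitive still carries the relevant attribute (axis/shape/perm/multiples/depth), append its abstract to the operand list so shape inference sees it as a regular input. A compact sketch of that fallback-key lookup, with invented stand-in types (not the actual PrimitivePtr/AbstractBasePtrList):

```cpp
// Sketch of the RectifyAbstract pattern: pick the first attr key an op
// carries and append its value to the abstract/operand list.
#include <iostream>
#include <map>
#include <string>
#include <vector>

using Prim = std::map<std::string, std::string>;  // attr name -> value

void Rectify(const Prim &prim, const std::vector<std::string> &keys,
             std::vector<std::string> *abs_list) {
  // Mirrors e.g. TransposeOp: try "input_perm" first, then fall back to "perm".
  for (const auto &key : keys) {
    auto it = prim.find(key);
    if (it != prim.end()) {
      abs_list->push_back(it->second);
      return;
    }
  }
}

int main() {
  Prim transpose{{"perm", "(1,0)"}};
  std::vector<std::string> abs_list{"x"};
  Rectify(transpose, {"input_perm", "perm"}, &abs_list);
  std::cout << abs_list.back() << '\n';  // (1,0)
  return 0;
}
```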
@@ -96,6 +96,7 @@ class ReshapeOp : public PrimOp {
  ~ReshapeOp() = default;

 protected:
  void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list);
  DFormat InferFormat(const NodePtrList &, const DAttrs &attrs) override {
    return attrs.find("format") == attrs.end() ? kOpFormat_DEFAULT
                                               : GetValue<std::string>(attrs.find("format")->second);

@@ -118,12 +119,22 @@ class BroadcastOp : public PrimOp {
  ~BroadcastOp() = default;
};

class TileOp : public BroadcastOp {
 public:
  explicit TileOp(const std::string &op) : BroadcastOp(op) {}
  ~TileOp() = default;

 protected:
  void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
};

class ReduceOp : public PrimOp {
 public:
  explicit ReduceOp(const std::string &op) : PrimOp(op, ComputeType::REDUCE) {}
  ~ReduceOp() = default;

 protected:
  void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
  DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; };
};

@@ -159,9 +170,18 @@ class TransposeOp : public OpaqueOp {
  ~TransposeOp() = default;

 protected:
  void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
  DFormat InferFormat(const NodePtrList &inputs, const DAttrs &attrs) override;
};

class OneHotOp : public OpaqueOp {
 public:
  explicit OneHotOp(const std::string &op) : OpaqueOp(op) {}
  ~OneHotOp() = default;

 protected:
  void RectifyAbstract(const PrimitivePtr &prim, AbstractBasePtrList *abs_list) override;
};
class LayoutTransformOp : public OpaqueOp {
 public:
  explicit LayoutTransformOp(const std::string &op) : OpaqueOp(op) {}
@@ -109,18 +109,18 @@ AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node) {
  }
  auto expander = graphkernel::GetExpander(node);
  std::map<std::string, graphkernel::ExpanderCreatorFuncList> creators = {
    {prim::kPrimExpandDims->name(), {InputToAttrDeco::GetCreator({1})}},
    {prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1})}},
    {prim::kPrimReduceMean->name(), {InputToAttrDeco::GetCreator({1})}},
    {prim::kPrimGather->name(), {InputToAttrDeco::GetCreator({2})}},
    {kTileOpName, {InputToAttrDeco::GetCreator({1})}},
    {kSliceOpName, {InputToAttrDeco::GetCreator({1, 2})}},
    {prim::kPrimExpandDims->name(), {InputToAttrDeco::Creator}},
    {prim::kPrimReshape->name(), {InputToAttrDeco::Creator}},
    {prim::kPrimReduceMean->name(), {InputToAttrDeco::Creator}},
    {prim::kPrimGather->name(), {InputToAttrDeco::Creator}},
    {kTileOpName, {InputToAttrDeco::Creator}},
    {kSliceOpName, {InputToAttrDeco::Creator}},
  };
  const auto &iter = creators.find(GetCNodePrimitive(node)->name());
  if (iter != creators.end()) {
    expander = graphkernel::WrapExpander(expander, iter->second);
  }
  expander = graphkernel::AttrToInputDeco::GetCreator(false)(expander);
  expander = graphkernel::AttrToInputDeco::Creator(expander);
  expander = PrimToPrimPyDecorator::Creator(expander);
  auto new_node = expander->Run(node);
  auto expand_fg = GetCNodeFuncGraph(new_node);
@@ -481,6 +481,7 @@ constexpr auto kAttrReshapeType = "reshape_type";
constexpr auto kAttrAxis = "axis";
constexpr auto kAttrAxes = "axes";
constexpr auto kAttrKeepDims = "keep_dims";
constexpr auto kAttrSkipMode = "skip_mode";
constexpr auto kAttrShapeGamma = "shape_gamma";
constexpr auto kAttrPerm = "perm";
constexpr auto kAttrTransposeFirst = "transpose_first";
@@ -35,6 +35,7 @@
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/utils.h"
#include "mindspore/core/ops/core_ops.h"

namespace mindspore {
namespace kernel {

@@ -76,6 +77,7 @@ class CNodeDecoder {
    }
    CreateKernelInfo(processor);
    CreateAbstract();
    InitIOName(cnode_);
    return cnode_;
  }

@@ -109,6 +111,21 @@ class CNodeDecoder {
    }
  }

  void InitIOName(const CNodePtr &cnode) {
    auto primitive = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(primitive);
    const auto &op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
    auto const iter = op_primc_fns.find(primitive->name());
    if (iter == op_primc_fns.end()) {
      return;
    }
    auto prim = iter->second();
    if (prim != nullptr) {
      (void)primitive->AddAttr(kAttrInputNames, prim->GetAttr(kAttrInputNames));
      (void)primitive->AddAttr(kAttrOutputNames, prim->GetAttr(kAttrOutputNames));
    }
  }

  bool DecodeAttrs(const nlohmann::json &attrs_json) {
    MS_LOG(DEBUG) << "start decoding attrs, " << attrs_json;
    // attrs may be empty
@@ -50,7 +50,7 @@ void EmbeddingLookUpCommGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node
  if (IsDynamic(input_shape_)) {
    return;
  }
  if (input_shape_.size() < 1) {
  if (input_shape_.empty()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the dimension of input must be at least 1-D, but got: " << input_shape_.size() << "-D";
  }
@@ -23,7 +23,7 @@ namespace kernel {
namespace {
constexpr size_t kEmbeddingLookupInputsNum = 3;
constexpr size_t kEmbeddingLookUpInputParamsMaxDim = 2;
constexpr size_t kIndex2 = 2;
constexpr size_t kOffsetIndex = 2;
using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;

#define ADD_KERNEL(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type) \

@@ -152,7 +152,7 @@ bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
  T *input_params_addr = reinterpret_cast<T *>(inputs[0]->addr);
  S *input_indices_addr = reinterpret_cast<S *>(inputs[1]->addr);
  T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
  G offset = static_cast<G *>(inputs[kIndex2]->addr)[0];
  G offset = static_cast<G *>(inputs[kOffsetIndex]->addr)[0];
  offset_ = static_cast<int64_t>(offset);

  auto task = [&](size_t start, size_t end) {
@@ -112,70 +112,8 @@ std::vector<KernelAttr> MemcpyCpuKernelMod::common_two_valid_types_with_bool_com

std::vector<KernelAttr> MemcpyCpuKernelMod::GetOpSupport() {
  static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
    {kReshape,
     {
       KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
       KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
       KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
       KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
       KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
       KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
       KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
       KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
       KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
       KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
       KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
       KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
       KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
       // double input
       // index int64
       KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
       KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
       KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
       KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
       KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
       KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
       KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
       KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
       KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
       KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
       KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
       KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
       // index int32
       KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
       KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
       KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
       KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
       KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
       KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
       KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
       KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
       KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
       KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
       KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
       KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
       // complex
       KernelAttr()
         .AddInputAttr(kNumberTypeComplex64)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeComplex64),
       KernelAttr()
         .AddInputAttr(kNumberTypeComplex128)
         .AddInputAttr(kNumberTypeInt64)
         .AddOutputAttr(kNumberTypeComplex128),
       KernelAttr()
         .AddInputAttr(kNumberTypeComplex64)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeComplex64),
       KernelAttr()
         .AddInputAttr(kNumberTypeComplex128)
         .AddInputAttr(kNumberTypeInt32)
         .AddOutputAttr(kNumberTypeComplex128),
     }},
    {kFlatten, common_valid_types_with_bool_complex_},
    {kFlattenGrad, common_two_valid_types_with_bool_complex_},
    {kExpandDims, common_two_valid_types_with_bool_complex_},
    {kReshape, common_two_valid_types_with_bool_complex_}, {kFlatten, common_valid_types_with_bool_complex_},
    {kFlattenGrad, common_two_valid_types_with_bool_complex_}, {kExpandDims, common_two_valid_types_with_bool_complex_},
    {kSqueeze, common_valid_types_with_bool_complex_},
  };
@@ -81,6 +81,7 @@ class ReduceCpuKernelFunc : public CpuKernelFunc {
  bool simple_execute_{false};
  std::string kernel_name_;
  bool need_skip_execute_{false};
  bool skip_mode_{false};
};

template <typename T>

@@ -234,13 +235,11 @@ int ReduceCpuKernelFunc<T>::Resize(const BaseOperatorPtr &base_operator, const s
                                   const std::vector<KernelTensorPtr> &,
                                   const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  input_shape_ = inputs[0]->GetDeviceShapeAdaptively();
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
  if (kernel_ptr->HasAttr(kAttrAxis)) {
    axis_ = kernel_ptr->get_axis();
  if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &axis_, false)) {
    MS_LOG(EXCEPTION) << "For " << kernel_name_ << ", can't get the axis input!";
  }
  (void)TryGetIntValue(inputs, kAxisIndex_, kernel_name_, &axis_, false);
  if (inputs.size() > kAxisIndex_ &&
      AnfAlgo::IsDynamicShapeSkipExecute(kernel_name_, inputs[kAxisIndex_]->GetShapeVector())) {
      AnfAlgo::IsDynamicShapeSkipExecute(skip_mode_, inputs[kAxisIndex_]->GetShapeVector())) {
    need_skip_execute_ = true;
  } else {
    need_skip_execute_ = false;

@@ -302,6 +301,8 @@ void ReduceCpuKernelFunc<T>::InitFunc(const BaseOperatorPtr &base_operator, cons
  MS_EXCEPTION_IF_NULL(base_operator);
  kernel_name_ = base_operator->name();
  ChooseFunc(kernel_name_);
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
  skip_mode_ = kernel_ptr->get_skip_mode();
}

template <typename T>

@@ -442,20 +443,8 @@ using SpecializeReduceFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>
#define REDUCE_CPU_REG(MS_T, MS_S, T) \
  KernelAttr().AddInputAttr(MS_T).AddInputAttr(MS_S).AddOutputAttr(MS_T), SpecializeReduceFunc<T>
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_all_any_list = {
  {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeReduceFunc<bool>},
  {REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt32, bool)},
  {REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt64, bool)}};
  {REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt32, bool)}, {REDUCE_CPU_REG(kNumberTypeBool, kNumberTypeInt64, bool)}};
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_max_min_list = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SpecializeReduceFunc<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), SpecializeReduceFunc<double>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), SpecializeReduceFunc<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SpecializeReduceFunc<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SpecializeReduceFunc<int32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SpecializeReduceFunc<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SpecializeReduceFunc<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), SpecializeReduceFunc<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), SpecializeReduceFunc<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), SpecializeReduceFunc<uint64_t>},
  {REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt32, float)},
  {REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float)},
  {REDUCE_CPU_REG(kNumberTypeFloat64, kNumberTypeInt32, double)},

@@ -477,20 +466,6 @@ static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_ma
  {REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt32, uint64_t)},
  {REDUCE_CPU_REG(kNumberTypeUInt64, kNumberTypeInt64, uint64_t)}};
static std::vector<std::pair<KernelAttr, SpecializeReduceFuncCreator>> kernel_sum_prod_mean_list = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SpecializeReduceFunc<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), SpecializeReduceFunc<double>},
  {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), SpecializeReduceFunc<int8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), SpecializeReduceFunc<int16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SpecializeReduceFunc<int32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), SpecializeReduceFunc<int64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), SpecializeReduceFunc<uint8_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), SpecializeReduceFunc<uint16_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), SpecializeReduceFunc<uint32_t>},
  {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), SpecializeReduceFunc<uint64_t>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
   SpecializeReduceFunc<complex64>},
  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
   SpecializeReduceFunc<complex128>},
  {REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt32, float)},
  {REDUCE_CPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float)},
  {REDUCE_CPU_REG(kNumberTypeFloat64, kNumberTypeInt32, double)},
@ -25,8 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kSliceInputsNum = 1;
|
||||
constexpr size_t kSliceDynamicInputNum = 3;
|
||||
constexpr size_t kSliceInputsNum = 3;
|
||||
constexpr size_t kSliceOutputsNum = 1;
|
||||
constexpr size_t kSliceInputIndex2 = 2;
|
||||
constexpr size_t kSliceTwoDims = 2;
|
||||
|
@ -59,12 +58,23 @@ bool SliceCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
|
|||
{kNumberTypeFloat16, sizeof(float16)},
|
||||
{kNumberTypeComplex64, sizeof(std::complex<float>)},
|
||||
{kNumberTypeComplex128, sizeof(std::complex<double>)}};
|
||||
|
||||
size_t input_num = inputs.size();
|
||||
if (input_num != kSliceInputsNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "input size should be " << kSliceInputsNum;
|
||||
}
|
||||
|
||||
TypeId dtype = inputs[0]->GetDtype();
|
||||
auto size_pair = type_size_map.find(dtype);
|
||||
if (size_pair == type_size_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Slice supports type in type_size_map, but got " << TypeIdToType(dtype)->ToString();
|
||||
}
|
||||
data_size_ = size_pair->second;
|
||||
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(inputs.at(1)->GetDtype() == inputs[kSliceInputIndex2]->GetDtype(),
|
||||
"Begin and size dtype should be same.");
|
||||
param_dtype_ = inputs.at(1)->GetDtype();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -103,11 +113,17 @@ void SliceCpuKernelMod::InitSliceParam(const ShapeVector &input_shape, const std
|
|||
int SliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::Slice>(base_operator);
|
||||
is_got_value_ = false;
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
size_t input_num = inputs.size();
|
||||
if (input_num != kSliceInputsNum) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "input size should be " << kSliceInputsNum;
|
||||
}
|
||||
|
||||
auto input_shape = inputs[0]->GetShapeVector();
|
||||
if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
|
@ -115,23 +131,26 @@ int SliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
|
|||
<< "D.";
|
||||
}
|
||||
|
||||
size_t input_num = inputs.size();
|
||||
// begin and size are const input
|
||||
if (input_num == 1) {
|
||||
auto size = kernel_ptr->get_size();
|
||||
auto begin = kernel_ptr->get_begin();
|
||||
std::vector<int64_t> begin;
|
||||
std::vector<int64_t> size;
|
||||
auto got_begin = TryGetIntValue(inputs, 1, kernel_name_, &begin, false);
|
||||
auto got_size = TryGetIntValue(inputs, kSliceInputIndex2, kernel_name_, &size, false);
|
||||
if (got_begin && got_size) {
|
||||
if (begin.size() != input_shape.size() || size.size() != input_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the lengths of 'begin' and 'size' must be equal to "
|
||||
"the dimension of input tensor, but got the length of 'begin' "
|
||||
<< begin.size() << ", the length of 'size' " << size.size()
|
||||
<< "and the dimension of input tensor " << input_shape.size();
|
||||
<< begin.size() << ", the length of 'size' " << size.size() << "and the dimension of input tensor "
|
||||
<< input_shape.size();
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
InitSliceParam(input_shape, begin, size);
|
||||
} else if (input_num == kSliceDynamicInputNum) {
|
||||
is_got_value_ = true;
|
||||
} else {
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
begin_shape_ = inputs[1]->GetShapeVector();
|
||||
size_shape_ = inputs[kSliceInputIndex2]->GetShapeVector();
|
||||
is_got_value_ = false;
|
||||
}
|
||||
return KRET_OK;
|
||||
}
|
||||
|
@ -150,45 +169,63 @@ void SliceSimpleDim2(const int8_t *input, int8_t *output, const SliceParameter *
|
|||
|
||||
bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
-if (inputs.size() != kSliceInputsNum && inputs.size() != kSliceDynamicInputNum) {
-MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kSliceInputsNum << " or "
-<< kSliceDynamicInputNum << ", but got " << inputs.size() << " input(s).";
+if (inputs.size() != kSliceInputsNum) {
+MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kSliceInputsNum
+<< ", but got " << inputs.size() << " input(s).";
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSliceOutputsNum, kernel_name_);

-auto input_addr = inputs[0]->addr;
-auto output_addr = outputs[0]->addr;
-if (inputs.size() == kSliceDynamicInputNum) {
+if (!is_got_value_) {
// Get begin and size value.
auto input_shape = input_shape_;
auto begin_shape = begin_shape_;
auto size_shape = size_shape_;
if (begin_shape.size() != 1 || size_shape.size() != 1) {
-MS_LOG(EXCEPTION) << "For '" << kernel_name_
+MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the dimensions of 'begin' and 'size' must be 1, but got the dimension of 'begin': "
<< begin_shape.size() << " and the dimension of 'size': " << size_shape.size();
+return false;
}
if (begin_shape[0] != SizeToLong(input_shape.size()) || size_shape[0] != SizeToLong(input_shape.size())) {
-MS_LOG(EXCEPTION) << "For '" << kernel_name_
+MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the lengths of 'begin' and 'size' must be equal to "
"the dimension of input tensor, but got the length of 'begin' "
<< begin_shape[0] << ", the length of 'size' " << size_shape[0]
<< "and the dimension of input tensor " << input_shape.size();
+return false;
}

+std::vector<int64_t> begin;
+std::vector<int64_t> size;
if (param_dtype_ == kNumberTypeInt32) {
auto begin_ptr = reinterpret_cast<int32_t *>(inputs[1]->addr);
auto size_ptr = reinterpret_cast<int32_t *>(inputs[kSliceInputIndex2]->addr);
-std::vector<int64_t> begin{begin_ptr, begin_ptr + begin_shape[0]};
-std::vector<int64_t> size{size_ptr, size_ptr + size_shape[0]};
+begin.assign(begin_ptr, begin_ptr + begin_shape[0]);
+size.assign(size_ptr, size_ptr + size_shape[0]);
} else if (param_dtype_ == kNumberTypeInt64) {
auto begin_ptr = reinterpret_cast<int64_t *>(inputs[1]->addr);
auto size_ptr = reinterpret_cast<int64_t *>(inputs[kSliceInputIndex2]->addr);
begin.assign(begin_ptr, begin_ptr + begin_shape[0]);
size.assign(size_ptr, size_ptr + size_shape[0]);
} else {
MS_LOG(ERROR) << "Invalid dtype: " << TypeIdLabel(param_dtype_);
return false;
}

for (size_t i = 0; i < begin.size(); ++i) {
if (input_shape[i] < begin[i] + size[i]) {
-MS_LOG(EXCEPTION) << "For '" << kernel_name_
+MS_LOG(ERROR) << "For '" << kernel_name_
<< "', slice shape should be not greater than origin shape. But in dimension i=" << i
<< ", origin shape 'input_shape[i]' is " << input_shape[i]
<< " and slice shape 'LongToSize(begin[i] + size[i])' is " << LongToSize(begin[i] + size[i]);
+return false;
}
}
InitSliceParam(input_shape, begin, size);
}

+auto input_addr = inputs[0]->addr;
+auto output_addr = outputs[0]->addr;
if (origin_dim_size_ == kSliceTwoDims) {
auto task = [this, &input_addr, &output_addr](size_t start, size_t end) {
auto src =

@ -204,6 +241,22 @@ bool SliceCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
return true;
}

+#define SLICE_CPU_REGISTER_KERNEL_ATTR(DT) \
+KernelAttr().AddInputAttr(DT).AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(DT), \
+KernelAttr().AddInputAttr(DT).AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(DT)
+
+std::vector<KernelAttr> SliceCpuKernelMod::GetOpSupport() {
+static const std::vector<KernelAttr> support_list = {
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeBool), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt8),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt16), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt32),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeInt64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt8),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt16), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt32),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeUInt64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat16),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat32), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeFloat64),
+SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeComplex64), SLICE_CPU_REGISTER_KERNEL_ATTR(kNumberTypeComplex128)};
+return support_list;
+}
+
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Slice, SliceCpuKernelMod);
} // namespace kernel
} // namespace mindspore
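Aside: the new Launch path above reads 'begin' and 'size' from raw tensor buffers whose element type is only known at run time, widens them into int64 vectors, and bounds-checks them before slicing. A standalone C++ sketch of that pattern (all names here are illustrative, not the MindSpore API):

// Standalone sketch of the "read begin/size from a typed raw buffer" pattern.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class Dtype { kInt32, kInt64 };

// Widen a raw buffer of int32/int64 elements into a vector<int64_t>,
// mirroring the begin.assign(begin_ptr, begin_ptr + n) calls in the diff.
std::vector<int64_t> ReadIndexVector(const void *addr, size_t n, Dtype dtype) {
  std::vector<int64_t> out;
  if (dtype == Dtype::kInt32) {
    auto p = static_cast<const int32_t *>(addr);
    out.assign(p, p + n);
  } else {
    auto p = static_cast<const int64_t *>(addr);
    out.assign(p, p + n);
  }
  return out;
}

int main() {
  std::vector<int64_t> input_shape{4, 5};
  int32_t begin_buf[] = {1, 2};
  int32_t size_buf[] = {2, 3};
  auto begin = ReadIndexVector(begin_buf, 2, Dtype::kInt32);
  auto size = ReadIndexVector(size_buf, 2, Dtype::kInt32);
  // Same bounds check the kernel performs: the slice must fit in the input.
  for (size_t i = 0; i < begin.size(); ++i) {
    if (input_shape[i] < begin[i] + size[i]) {
      std::cerr << "slice out of range in dim " << i << "\n";
      return 1;
    }
  }
  std::cout << "begin/size OK\n";
  return 0;
}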
@ -14,8 +14,8 @@
* limitations under the License.
*/

-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_

#include <vector>
#include <memory>

@ -39,39 +39,7 @@ class SliceCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

-std::vector<KernelAttr> GetOpSupport() override {
-static const std::vector<KernelAttr> support_list = {
-KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
-KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
-KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
-KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
-KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
-KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
-KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
-KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
-KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
-KernelAttr()
-.AddInputAttr(kNumberTypeFloat32)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt32)
-.AddOutputAttr(kNumberTypeFloat32),
-KernelAttr()
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt32)
-.AddOutputAttr(kNumberTypeInt32),
-KernelAttr()
-.AddInputAttr(kNumberTypeBool)
-.AddInputAttr(kNumberTypeInt32)
-.AddInputAttr(kNumberTypeInt32)
-.AddOutputAttr(kNumberTypeBool)};
-return support_list;
-}
+std::vector<KernelAttr> GetOpSupport() override;

private:
void InitSliceParam(const ShapeVector &input_shape, const std::vector<int64_t> &begin,

@ -80,10 +48,12 @@ class SliceCpuKernelMod : public NativeCpuKernelMod {
int data_size_{4};
SliceParameter slice_param_;
size_t output_num_{1};
TypeId param_dtype_{kNumberTypeInt32};
+std::vector<int64_t> input_shape_;
+std::vector<int64_t> begin_shape_;
+std::vector<int64_t> size_shape_;
+bool is_got_value_{false};
};
} // namespace kernel
} // namespace mindspore
-#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
+#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SLICE_CPU_KERNEL_H_
@ -28,8 +28,7 @@
namespace mindspore {
namespace kernel {
namespace {
-constexpr size_t kStridedSliceInputsNum = 1;
-constexpr size_t kStridedSliceDynamicInputsNum = 4;
+constexpr size_t kStridedSliceInputsNum = 4;
constexpr size_t kStridedSliceOutputsNum = 1;

using complex64 = std::complex<float>;

@ -58,10 +57,11 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (ret != KRET_OK) {
return ret;
}
-if (inputs.size() != kStridedSliceInputsNum && inputs.size() != kStridedSliceDynamicInputsNum) {
+if (inputs.size() != kStridedSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kStridedSliceInputsNum
-<< " or " << kStridedSliceDynamicInputsNum << ", but got " << inputs.size();
+<< ", but got " << inputs.size();
}

CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
input_shape_ = inputs[0]->GetShapeVector();
dtype_ = inputs[0]->GetDtype();

@ -76,7 +76,6 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
InitParallelParam();
}

-if (inputs.size() == kStridedSliceDynamicInputsNum) {
// for begin, end, stride are not const input
begin_shape_ = inputs[kIndex1]->GetShapeVector();
end_shape_ = inputs[kIndex2]->GetShapeVector();

@ -88,14 +87,7 @@ int StridedSliceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
<< begin_shape_.size() << ", the dimension of 'end': " << end_shape_.size()
<< ", and the dimension of 'strides': " << stride_shape_.size();
}
-} else {
-// for begin, end, stride are tuples
-auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
-auto begin = kernel_ptr->get_begin();
-auto end = kernel_ptr->get_end();
-auto stride = kernel_ptr->get_strides();
-InitSliceParam(base_operator, &begin, &end, &stride);
-}

return ret;
}

@ -244,16 +236,14 @@ template <typename T, typename S>
bool StridedSliceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
-if (inputs.size() != kStridedSliceInputsNum && inputs.size() != kStridedSliceDynamicInputsNum) {
+if (inputs.size() != kStridedSliceInputsNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kStridedSliceInputsNum
-<< " or " << kStridedSliceDynamicInputsNum << ", but got " << inputs.size();
+<< ", but got " << inputs.size();
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kStridedSliceOutputsNum, kernel_name_);
auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);

-size_t input_num = inputs.size();
-if (input_num == kStridedSliceDynamicInputsNum) {
// for begin, end, stride are tensors
std::vector<int64_t> begin;
std::vector<int64_t> end;

@ -271,7 +261,6 @@ bool StridedSliceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr
stride.push_back(static_cast<int64_t>(strides_ptr[i]));
}
InitSliceParam(base_operator_, &begin, &end, &stride);
-}

int thread_num = slice_param_.op_parameter_.thread_num_;
if (parallel_ && thread_num >= 2) {
@ -17,13 +17,14 @@
#include "plugin/device/cpu/kernel/tile_cpu_kernel.h"
#include <algorithm>
#include <map>
+#include <numeric>
+#include <functional>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
-constexpr size_t kTileInputsNum = 1;
-constexpr size_t kTileDynamicInputsNum = 2;
+constexpr size_t kTileInputsNum = 2;
constexpr size_t kTileOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;

@ -84,18 +85,14 @@ void TileCpuKernelMod::TileMultipleCompute() {
bool TileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = base_operator->name();
-input_num = inputs.size();
-dtype_ = inputs[kIndex0]->GetDtype();
-multiples_.clear();
-if (input_num == kTileInputsNum) {
-std::vector<int64_t> multiples_me = GetValue<std::vector<int64_t>>(prim->GetAttr("multiples"));
-(void)std::transform(multiples_me.begin(), multiples_me.end(), std::back_inserter(multiples_),
-[](const int64_t &value) { return static_cast<int>(value); });
+input_num_ = inputs.size();
+if (input_num_ != kTileInputsNum) {
+MS_LOG(EXCEPTION) << "Tile's inputs number should be " << kTileInputsNum << ", but got " << input_num_;
}

+multiples_.clear();
+dtype_ = inputs[kIndex0]->GetDtype();
launch_map_[kNumberTypeInt8] = &TileCpuKernelMod::LaunchKernel<int8_t>;
launch_map_[kNumberTypeInt16] = &TileCpuKernelMod::LaunchKernel<int16_t>;
launch_map_[kNumberTypeInt32] = &TileCpuKernelMod::LaunchKernel<int>;

@ -129,20 +126,20 @@ int TileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}

x_shape_ = inputs[kIndex0]->GetShapeVector();
y_shape_ = outputs[kIndex0]->GetShapeVector();
-if (input_num == kTileDynamicInputsNum) {
-multiple_shape = inputs[kIndex1]->GetShapeVector();
+multiple_shape_ = inputs[kIndex1]->GetShapeVector();
+multiple_dtype_ = inputs[kIndex1]->GetDtype();
-}

return KRET_OK;
}

bool TileCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
-if (inputs.size() != kTileInputsNum && inputs.size() != kTileDynamicInputsNum) {
-MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input must be " << kTileInputsNum << " or "
-<< kTileDynamicInputsNum << ", but got " << inputs.size();
+if (inputs.size() != kTileInputsNum) {
+MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input must be " << kTileInputsNum << ", but got "
+<< inputs.size();
}
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTileOutputsNum, kernel_name_);
launch_func_(this, inputs, outputs);

@ -153,15 +150,8 @@ template <typename T>
void TileCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
auto x_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto y_addr = reinterpret_cast<T *>(outputs[0]->addr);
-if (input_num == kTileInputsNum) {
-TileMultipleCompute();
-}
-if (input_num == kTileDynamicInputsNum) {
multiples_.clear();
-int64_t multiple_nums = 1;
-for (size_t i = 0; i < multiple_shape.size(); ++i) {
-multiple_nums *= multiple_shape[i];
-}
+auto multiple_nums = std::accumulate(multiple_shape_.begin(), multiple_shape_.end(), 1, std::multiplies<int64_t>());
if (multiple_dtype_ == kNumberTypeInt32) {
auto multiples_addr = GetDeviceAddress<int32_t>(inputs, 1);
for (size_t i = 0; i < LongToSize(multiple_nums); ++i) {

@ -174,7 +164,6 @@ void TileCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const
}
}
TileMultipleCompute();
-}

tile_parameter_.data_size_ = sizeof(T);
if (one_dim_tile_) {
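Aside: the hand-rolled product loop over multiple_shape_ is replaced above by std::accumulate with std::multiplies, which is why <numeric> and <functional> are now included. A minimal standalone sketch of the same computation plus a 1-D tile:

// Standalone sketch: element count via std::accumulate, then a 1-D tile.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> multiple_shape{2};  // shape of the 'multiples' input
  // Replaces the explicit "multiple_nums *= multiple_shape[i]" loop:
  auto multiple_nums = std::accumulate(multiple_shape.begin(), multiple_shape.end(),
                                       int64_t{1}, std::multiplies<int64_t>());
  std::cout << "multiples has " << multiple_nums << " elements\n";

  // A minimal tile: repeat the input 'multiple' times along one axis.
  std::vector<int> x{1, 2, 3};
  int64_t multiple = 2;
  std::vector<int> y;
  for (int64_t m = 0; m < multiple; ++m) {
    y.insert(y.end(), x.begin(), x.end());
  }
  for (int v : y) std::cout << v << ' ';  // 1 2 3 1 2 3
  std::cout << '\n';
  return 0;
}

One detail worth noting: the sketch uses int64_t{1} as the initial value, while the diff passes a plain 1, which makes the accumulator an int even though the operation is std::multiplies<int64_t> — harmless for small shape vectors, but a known std::accumulate subtlety.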
@ -112,9 +112,9 @@ class TileCpuKernelMod : public NativeCpuKernelMod {
ShapeVector x_shape_;
ShapeVector y_shape_;
-ShapeVector multiple_shape;
-size_t input_num;
-size_t output_num;
+size_t input_num_;
std::vector<int> multiples_;
+ShapeVector multiple_shape_;
TypeId dtype_{kTypeUnknown};
+TypeId multiple_dtype_{kTypeUnknown};
@ -27,8 +27,7 @@
namespace mindspore {
namespace kernel {
namespace {
-constexpr size_t kTransposeInputsNum = 1;
-constexpr size_t kDynamicPermInputNum = 2;
+constexpr size_t kTransposeInputNum = 2;
constexpr size_t kTransposeOutputsNum = 1;
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;

@ -86,6 +85,9 @@ int TransposeFwdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
if (ret != KRET_OK) {
return ret;
}

+CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputNum, kernel_name_);

input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
output_shape_ = outputs[kIndex0]->GetDeviceShapeAdaptively();
dtype_ = inputs[kIndex0]->GetDtype();

@ -98,23 +100,24 @@ int TransposeFwdCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
strides_[i - 1] = input_shape_[i] * strides_[i];
out_strides_[i - 1] = output_shape_[i] * out_strides_[i];
}
-if (inputs.size() == kDynamicPermInputNum) {

perm_type_ = inputs[kIndex1]->GetDtype();
perm_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
-return KRET_OK;
-}
-auto prim = std::dynamic_pointer_cast<ops::Transpose>(base_operator);
-MS_ERROR_IF_NULL(prim);
-perm_ = prim->get_perm();

+perm_.clear();
+got_perm_value_ = TryGetIntValue(inputs, kIndex1, kernel_name_, &perm_, false);
+if (got_perm_value_) {
+CheckPermValue();
+}
return KRET_OK;
}

bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
+CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
-if (inputs.size() == kDynamicPermInputNum) {
+if (!got_perm_value_) {
if (perm_type_ == kNumberTypeInt32) {
InitPerm<int32_t>(inputs);
} else {

@ -126,34 +129,6 @@ bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inp
}

std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> TransposeFwdCpuKernelMod::launch_list_ = {
-{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
-&TransposeFwdCpuKernelMod::LaunchKernel<bool>},
-{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
-&TransposeFwdCpuKernelMod::LaunchKernel<int8_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
-&TransposeFwdCpuKernelMod::LaunchKernel<int16_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-&TransposeFwdCpuKernelMod::LaunchKernel<int32_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
-&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
-&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
-&TransposeFwdCpuKernelMod::LaunchKernel<uint32_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
-&TransposeFwdCpuKernelMod::LaunchKernel<uint64_t>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-&TransposeFwdCpuKernelMod::LaunchKernel<float16>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-&TransposeFwdCpuKernelMod::LaunchKernel<float>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-&TransposeFwdCpuKernelMod::LaunchKernel<double>},
-{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
-&TransposeFwdCpuKernelMod::LaunchKernel<complex64>},
-{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
-&TransposeFwdCpuKernelMod::LaunchKernel<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&TransposeFwdCpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),

@ -165,7 +140,7 @@ std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> Transpo
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
-&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
+&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),

@ -193,7 +168,7 @@ std::vector<std::pair<KernelAttr, TransposeFwdCpuKernelMod::TypeKernel>> Transpo
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
-&TransposeFwdCpuKernelMod::LaunchKernel<int64_t>},
+&TransposeFwdCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
&TransposeFwdCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
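Aside: the Resize change above resolves the perm once, at resize time, when the perm input is constant (got_perm_value_), and defers the read to Launch otherwise. A standalone sketch of this resolve-early-or-defer pattern using std::optional (illustrative only, not the MindSpore API):

// Standalone sketch: resolve a value at Resize when it is constant,
// otherwise read it from the runtime input at Launch.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct Kernel {
  std::optional<std::vector<int64_t>> perm_;  // set at Resize if the input is const

  void Resize(const std::optional<std::vector<int64_t>> &const_perm) {
    perm_ = const_perm;  // std::nullopt when perm is a runtime tensor
  }

  void Launch(const std::vector<int64_t> &runtime_perm) {
    // Fall back to the runtime value only when Resize could not resolve it.
    const auto &perm = perm_.has_value() ? *perm_ : runtime_perm;
    for (auto p : perm) std::cout << p << ' ';
    std::cout << '\n';
  }
};

int main() {
  Kernel k;
  k.Resize(std::nullopt);                 // dynamic case: unknown at resize time
  k.Launch({1, 0});                       // resolved from the input tensor at launch
  k.Resize(std::vector<int64_t>{0, 1});   // const case: resolved early
  k.Launch({9, 9});                       // runtime value ignored
  return 0;
}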
@ -80,6 +80,7 @@ class TransposeFwdCpuKernelMod : public NativeCpuKernelMod {
size_t data_num_{0};
std::vector<int64_t> strides_;
std::vector<int64_t> out_strides_;
+bool got_perm_value_{false};

using TypeKernel =
std::function<void(TransposeFwdCpuKernelMod *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
@ -1,5 +1,5 @@
/**
- * Copyright 2021 Huawei Technologies Co., Ltd
+ * Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@ -48,6 +48,25 @@ std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) {
return axis;
}

+ValueNodePtr CreateValueNode(const std::vector<int64_t> &transpose_perm, const FuncGraphPtr &graph) {
+MS_EXCEPTION_IF_NULL(graph);
+auto tensor_ptr = std::make_shared<tensor::Tensor>(transpose_perm, TypeIdToType(kNumberTypeInt64));
+MS_EXCEPTION_IF_NULL(tensor_ptr);
+auto value_node = std::make_shared<ValueNode>(tensor_ptr);
+MS_EXCEPTION_IF_NULL(value_node);
+value_node->set_abstract(tensor_ptr->ToAbstract());
+
+auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(graph);
+if (kernel_graph != nullptr) {
+value_node = kernel_graph->NewValueNode(value_node);
+kernel_graph->AddValueNodeToGraph(value_node);
+} else {
+value_node = MakeValueNode(value_node);
+}
+
+return value_node;
+}
+
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
int used_node_index, const std::vector<int64_t> &transpose_perm) {
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()

@ -58,13 +77,13 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
auto transpose_prim = std::make_shared<Primitive>(primitive_ptr->name());
MS_EXCEPTION_IF_NULL(transpose_prim);
// 2.Set the input of transpose.
-std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
+auto perm_value_node = CreateValueNode(transpose_perm, graph);
+std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node, perm_value_node};
auto transpose_op = graph->NewCNode(transpose_input);
// 3.Set the output info of transpose.
auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, IntToSize(used_node_index))};
auto transpose_shape = {common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, IntToSize(used_node_index))};
common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, transpose_shape, transpose_op.get());
-common::AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
// 4. Set the new edge of transpose op.
FuncGraphManagerPtr manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);

@ -78,8 +97,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
auto input_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
auto output_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
-builder.SetInputsFormat({input_format});
-builder.SetInputsDeviceType({input_type});
+builder.SetInputsFormat({input_format, kOpFormat_DEFAULT});
+builder.SetInputsDeviceType({input_type, kNumberTypeInt64});
builder.SetOutputsFormat({output_format});
builder.SetOutputsDeviceType({output_type});
builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
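Aside: with perm passed as a second input (an int64 tensor ValueNode) rather than the kAttrPerm attribute, the inserted Transpose still computes output_shape[i] = input_shape[perm[i]]. A standalone sketch of that shape rule; the NCHW-to-NHWC perm {0, 2, 3, 1} is a typical value such a format-transform pass would build:

// Standalone sketch of applying a transpose permutation to a shape.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> Permute(const std::vector<int64_t> &shape, const std::vector<int64_t> &perm) {
  std::vector<int64_t> out(shape.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = shape[static_cast<size_t>(perm[i])];  // output dim i comes from input dim perm[i]
  }
  return out;
}

int main() {
  std::vector<int64_t> nchw{1, 3, 224, 224};
  for (auto d : Permute(nchw, {0, 2, 3, 1})) std::cout << d << ' ';  // 1 224 224 3
  std::cout << '\n';
  return 0;
}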
@ -24,14 +24,6 @@ namespace mindspore::opt {

RER_CPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
RER_CPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceMinOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceSumOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kReduceProdOpName, 1);
-RER_CPU_DYNAMIC_CONST_TO_ATTR(kTransposeOpName, 1);

RER_CPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);

@ -48,19 +40,7 @@ RER_CPU_STATIC_CONST_TO_ATTR(kCumprodOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
RER_CPU_STATIC_CONST_TO_ATTR(kMeanGradOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kParallelResizeBilinearGradOpName, 2);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceAllOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceMinOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceProdOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_CPU_STATIC_CONST_TO_ATTR(kReshapeOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
RER_CPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
-RER_CPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
-RER_CPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
-RER_CPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
} // namespace mindspore::opt

#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_OPTIMIZER_REG_CPU_CONST_INPUT_TO_ATTR_H_
@ -48,72 +48,51 @@ const std::map<std::string, cudnnReduceTensorOp_t> kReduceTypeMap = {
{"ReduceProd", CUDNN_REDUCE_TENSOR_MUL},
};

-#define STATIC_REGISTER(INPUTX, T) \
-KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(INPUTX), &ArrayReduceGpuKernelMod::LaunchKernel<T>
-
-#define DYN_REGISTER(INPUTX, AXIS, T) \
+#define REDUCE_REGISTER(INPUTX, AXIS, T) \
KernelAttr().AddInputAttr(INPUTX).AddInputAttr(AXIS).AddOutputAttr(INPUTX), &ArrayReduceGpuKernelMod::LaunchKernel<T>

std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::all_any_list_ = {
-{STATIC_REGISTER(kNumberTypeBool, bool)},
-{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
-{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
+{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
+{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)}};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::prod_list_ = {
-{STATIC_REGISTER(kNumberTypeFloat64, double)},
-{STATIC_REGISTER(kNumberTypeFloat32, float)},
-{STATIC_REGISTER(kNumberTypeFloat16, half)},
-{STATIC_REGISTER(kNumberTypeInt8, int8_t)},
-{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
-{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
-{DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
-{DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt32, int8_t)},
+{REDUCE_REGISTER(kNumberTypeInt8, kNumberTypeInt64, int8_t)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::sum_list_ = {
-{STATIC_REGISTER(kNumberTypeFloat64, double)},
-{STATIC_REGISTER(kNumberTypeFloat32, float)},
-{STATIC_REGISTER(kNumberTypeFloat16, half)},
-{STATIC_REGISTER(kNumberTypeBool, bool)},
-{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
-{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
-{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
-{DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt32, bool)},
+{REDUCE_REGISTER(kNumberTypeBool, kNumberTypeInt64, bool)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>> ArrayReduceGpuKernelMod::max_min_mean_list_ = {
-{STATIC_REGISTER(kNumberTypeFloat64, double)},
-{STATIC_REGISTER(kNumberTypeFloat32, float)},
-{STATIC_REGISTER(kNumberTypeFloat16, half)},
-{STATIC_REGISTER(kNumberTypeComplex64, Complex<float>)},
-{STATIC_REGISTER(kNumberTypeComplex128, Complex<double>)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
-{DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
-{DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
-{DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
-{DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, half)},
+{REDUCE_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, half)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, float)},
+{REDUCE_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, float)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, double)},
+{REDUCE_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, double)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, Complex<float>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, Complex<double>)},
+{REDUCE_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, Complex<double>)},
};
std::map<std::string, std::vector<std::pair<KernelAttr, ArrayReduceGpuKernelMod::ReduceFunc>>>
ArrayReduceGpuKernelMod::kernel_attr_list_ = {

@ -171,6 +150,10 @@ bool ArrayReduceGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
data_type_ = GetCudnnDataType(type_name);
}

+auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
+keep_dims_ = kernel_ptr->get_keep_dims();
+skip_mode_ = kernel_ptr->get_skip_mode();
+
InitResource();
return true;
}

@ -202,27 +185,17 @@ int ArrayReduceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return ret;
}

-auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
auto inputA_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();

std::vector<int64_t> attr_axis;
-if (inputs.size() == kDynamicAxisInputNum) {
-if (AnfAlgo::IsDynamicShapeSkipExecute(kernel_name_, inputs[kIndex1]->GetShapeVector())) {
-if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &attr_axis)) {
-MS_LOG(EXCEPTION) << "For " << kernel_name_ << " can't get axis input! ";
-}
+if (AnfAlgo::IsDynamicShapeSkipExecute(skip_mode_, inputs[kIndex1]->GetShapeVector())) {
need_skip_execute_ = true;
// As size of input_size_list_ is equal to size of inputs, input_size_list_[0] is safe.
input_size_ = input_size_list_[0];
return KRET_OK;
}
+if (!TryGetIntValue(inputs, kIndex1, kernel_name_, &attr_axis)) {
+InitCudnnResource();
+return KRET_OK;
+}
-} else {
-if (kernel_ptr->HasAttr(kAttrAxis)) {
-attr_axis = kernel_ptr->get_axis();
-}
-}
-keep_dims_ = kernel_ptr->get_keep_dims();

int input_dim_length = SizeToInt(inputA_shape.size());
axis_.clear();
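Aside: REDUCE_REGISTER expands to one (KernelAttr, launch-function) pair per invocation, which is why every list entry above wraps a single macro call in braces. A standalone toy version of the same registration-table idiom (KernelAttr here is a stand-in, not the MindSpore class):

// Standalone sketch of a macro-built (attr, function) registration table.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct KernelAttr {
  std::string sig;
  KernelAttr &AddInputAttr(const std::string &t) { sig += "in:" + t + " "; return *this; }
  KernelAttr &AddOutputAttr(const std::string &t) { sig += "out:" + t; return *this; }
};

using ReduceFunc = std::function<void()>;

// Expands to the two pair elements, just like the diff's REDUCE_REGISTER.
#define REDUCE_REGISTER(INPUTX, AXIS, NAME) \
  { KernelAttr().AddInputAttr(INPUTX).AddInputAttr(AXIS).AddOutputAttr(INPUTX), \
    [] { std::cout << "launch " << NAME << "\n"; } }

int main() {
  std::vector<std::pair<KernelAttr, ReduceFunc>> sum_list = {
    REDUCE_REGISTER("float32", "int32", "sum<float>"),
    REDUCE_REGISTER("float32", "int64", "sum<float>"),
  };
  for (auto &entry : sum_list) {
    std::cout << entry.first.sig << " -> ";
    entry.second();
  }
  return 0;
}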
@ -68,6 +68,7 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {
outputC_descriptor_ = nullptr;
need_skip_execute_ = false;
keep_dims_ = false;
+skip_mode_ = false;
all_match_ = false;
is_null_input_ = false;
complex_op_type = 0;

@ -120,6 +121,7 @@ class ArrayReduceGpuKernelMod : public NativeGpuKernelMod {

std::vector<int> axis_;
bool keep_dims_;
+bool skip_mode_;
bool need_skip_execute_;
bool all_match_;
bool is_null_input_;
@ -31,9 +31,6 @@ bool OneHotGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::v
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input num should be 3 or 4, but get: " << inputs.size();
return false;
}
-if (inputs.size() == max_input_num) {
-is_dynamic_shape_ = true;
-}
constexpr size_t output_num = 1;
kernel_name_ = base_operator->GetPrim()->name();
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);

@ -94,10 +91,8 @@ bool OneHotGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, con
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
size_t on_value_idx = 1;
size_t off_value_idx = 2;
-if (is_dynamic_shape_) {
on_value_idx++;
off_value_idx++;
-}
const S *indices = GetDeviceAddress<S>(inputs, 0);
const T *on_value = GetDeviceAddress<T>(inputs, on_value_idx);
const T *off_value = GetDeviceAddress<T>(inputs, off_value_idx);

@ -55,7 +55,6 @@ class OneHotGpuKernelMod : public NativeGpuKernelMod {

static std::vector<std::pair<KernelAttr, OneHotLaunchFunc>> func_list_;
OneHotLaunchFunc kernel_func_;
-bool is_dynamic_shape_ = false;
size_t depth_{0};
size_t left_dim_size_{1};
size_t right_dim_size_{1};
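Aside: with the dynamic-shape branch gone, on_value and off_value now always sit at input indices 2 and 3 (after indices and depth). A standalone sketch of the one-hot computation itself:

// Standalone sketch of the one-hot fill the kernel performs per index.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> OneHot(const std::vector<int> &indices, size_t depth,
                          float on_value, float off_value) {
  std::vector<float> out(indices.size() * depth, off_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    // Out-of-range indices simply leave the row filled with off_value.
    if (indices[i] >= 0 && static_cast<size_t>(indices[i]) < depth) {
      out[i * depth + static_cast<size_t>(indices[i])] = on_value;
    }
  }
  return out;
}

int main() {
  for (float v : OneHot({1, 0, 2}, 3, 1.0f, 0.0f)) std::cout << v << ' ';
  std::cout << '\n';  // 0 1 0  1 0 0  0 0 1
  return 0;
}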
@ -26,20 +26,7 @@ std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSliceKernelPtr(const std::s
using SlicePtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

-const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {
-{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateSliceKernelPtr<double>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateSliceKernelPtr<float>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateSliceKernelPtr<half>},
-{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateSliceKernelPtr<int64_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateSliceKernelPtr<int32_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateSliceKernelPtr<int16_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateSliceKernelPtr<char>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateSliceKernelPtr<uint64_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateSliceKernelPtr<uint32_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateSliceKernelPtr<uint16_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateSliceKernelPtr<uchar>},
-{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateSliceKernelPtr<bool>},
-{KernelAttr()
+const std::vector<std::pair<KernelAttr, SlicePtrCreatorFunc>> kernel_attr = {{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
@ -34,13 +34,7 @@ class StridedSliceGpuCommon {
StridedSliceGpuCommon() : null_output_(false) {}
~StridedSliceGpuCommon() = default;

-void CollectInfo(const BaseOperatorPtr &base_operator, bool is_dynamic_attr_ = false) {
-if (!is_dynamic_attr_) {
-auto kernel_ptr = std::dynamic_pointer_cast<ops::StridedSlice>(base_operator);
-begin_ = kernel_ptr->get_begin();
-end_ = kernel_ptr->get_end();
-strides_ = kernel_ptr->get_strides();
-}
+void CollectInfo(const BaseOperatorPtr &base_operator) {
auto shape_tmp = Convert2Long(input_shape_);
FillEmptyDims(base_operator, &begin_, &end_, &strides_, &shape_tmp);
input_shape_ = Convert2SizeT(shape_tmp);
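Aside: CollectInfo now assumes begin_/end_/strides_ were already populated from the index inputs and only normalizes them via FillEmptyDims. A rough standalone sketch of that kind of normalization (illustrative only; the real helper also honors the begin/end/ellipsis masks, which this sketch omits):

// Standalone sketch: pad begin/end/strides to the input rank and wrap
// negative indices, roughly what a FillEmptyDims-style helper does.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

void FillEmptyDims(std::vector<int64_t> *begin, std::vector<int64_t> *end,
                   std::vector<int64_t> *strides, const std::vector<int64_t> &shape) {
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i >= begin->size()) begin->push_back(0);
    if (i >= end->size()) end->push_back(shape[i]);
    if (i >= strides->size()) strides->push_back(1);
    if ((*begin)[i] < 0) (*begin)[i] += shape[i];  // negative index wraps
    if ((*end)[i] < 0) (*end)[i] += shape[i];
    (*begin)[i] = std::clamp((*begin)[i], int64_t{0}, shape[i]);
    (*end)[i] = std::clamp((*end)[i], int64_t{0}, shape[i]);
  }
}

int main() {
  std::vector<int64_t> begin{1}, end{-1}, strides;
  FillEmptyDims(&begin, &end, &strides, {5, 4});
  std::cout << begin[0] << ' ' << end[0] << ' ' << strides[0] << '\n';  // 1 4 1
  std::cout << begin[1] << ' ' << end[1] << ' ' << strides[1] << '\n';  // 0 4 1
  return 0;
}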
@ -22,7 +22,6 @@

namespace mindspore {
namespace kernel {
-constexpr size_t DynamicInputNum = 4;
template <typename T>
using Complex = mindspore::utils::Complex<T>;

@ -57,10 +56,6 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
return ret;
}

-if (inputs.size() == DynamicInputNum) {
-is_dynamic_attr_ = true;
-}
-
auto shape_signed = inputs[0]->GetShapeVector();
input_shape_ = Convert2SizeTClipNeg(shape_signed);
null_output_ = CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input");

@ -71,12 +66,11 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
<< ", but got " << input_shape_.size();
}
-if (is_dynamic_attr_) {

TryGetIntValue(inputs, kBeginIndex_, kernel_name_, &begin_);
TryGetIntValue(inputs, kEndIndex_, kernel_name_, &end_);
TryGetIntValue(inputs, kStrideIndex_, kernel_name_, &strides_);
-}
-CollectInfo(base_operator, is_dynamic_attr_);
+CollectInfo(base_operator);

return ret;
}

@ -94,21 +88,6 @@ int StridedSliceGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const
&StridedSliceGpuKernelMod::LaunchKernel<TYPE_1, TYPE_2>

std::vector<std::pair<KernelAttr, StridedSliceGpuKernelMod::StridedSliceFunc>> StridedSliceGpuKernelMod::func_list_ = {
-{STRIDEDSLICE_GPU_REG(kNumberTypeFloat64, double)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeFloat32, float)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeFloat16, half)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeInt64, int64_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeInt32, int32_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeInt16, int16_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeInt8, int8_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeUInt64, uint64_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeUInt32, uint32_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeUInt16, uint16_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeUInt8, uint8_t)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeBool, bool)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeComplex64, Complex<float>)},
-{STRIDEDSLICE_GPU_REG(kNumberTypeComplex128, Complex<double>)},
-
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat64, kNumberTypeInt64, double, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat32, kNumberTypeInt64, float, int64_t)},
{STRIDEDSLICE_DYNAMIC_GPU_REG(kNumberTypeFloat16, kNumberTypeInt64, half, int64_t)},

@ -59,8 +59,6 @@ class StridedSliceGpuKernelMod : public NativeGpuKernelMod, public StridedSliceG
StridedSliceFunc kernel_func_;

bool is_null_input_{false};
-bool is_dynamic_attr_{false};
-bool get_dynamic_attr_value_{false};
static constexpr size_t kBeginIndex_{1};
static constexpr size_t kEndIndex_{2};
static constexpr size_t kStrideIndex_{3};
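Aside: once begin/end/strides are fetched with TryGetIntValue, the kernel's arithmetic is plain strided indexing. A standalone 1-D sketch:

// Standalone sketch of 1-D strided slicing: take x[begin:end:stride].
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int> StridedSlice1D(const std::vector<int> &x, int64_t begin, int64_t end, int64_t stride) {
  std::vector<int> out;
  // Positive strides walk forward to end; negative strides walk backward.
  for (int64_t i = begin; (stride > 0) ? (i < end) : (i > end); i += stride) {
    out.push_back(x[static_cast<size_t>(i)]);
  }
  return out;
}

int main() {
  std::vector<int> x{0, 1, 2, 3, 4, 5, 6, 7};
  for (int v : StridedSlice1D(x, 1, 7, 2)) std::cout << v << ' ';  // 1 3 5
  std::cout << '\n';
  return 0;
}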
@ -18,38 +18,6 @@

namespace mindspore {
namespace kernel {
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-TensorCopySlicesGpuKernelMod, double)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-TensorCopySlicesGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-TensorCopySlicesGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-TensorCopySlicesGpuKernelMod, int64_t)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-TensorCopySlicesGpuKernelMod, int)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
-TensorCopySlicesGpuKernelMod, char)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
-TensorCopySlicesGpuKernelMod, uchar)
-MS_REG_GPU_KERNEL_ONE(
-TensorCopySlices,
-KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
-TensorCopySlicesGpuKernelMod, bool)
MS_REG_GPU_KERNEL_TWO(TensorCopySlices,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
|
@ -70,9 +70,6 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
|
|||
return ret;
|
||||
}
|
||||
ResetResource();
|
||||
if (inputs.size() == DynamicInputNum) {
|
||||
is_dynamic_attr_ = true;
|
||||
}
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), OutputNum, kernel_name_);
|
||||
auto shape_signed = inputs.at(kIndex0)->GetShapeVector();
|
||||
input_shape_ = Convert2SizeTClipNeg(shape_signed);
|
||||
|
@ -87,17 +84,9 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
|
|||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << kMaxDims
|
||||
<< ", but got " << input_shape_.size();
|
||||
}
|
||||
if (is_dynamic_attr_) {
|
||||
TryGetIntValue(inputs, kBeginIndex_, kernel_name_, &begin_, false);
|
||||
TryGetIntValue(inputs, kEndIndex_, kernel_name_, &end_, false);
|
||||
TryGetIntValue(inputs, kStrideIndex_, kernel_name_, &strides_, false);
|
||||
} else {
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
begin_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrBegin));
|
||||
end_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrEnd));
|
||||
strides_ = GetValue<std::vector<int64_t>>(prim->GetAttr(kAttrStrides));
|
||||
}
|
||||
if (begin_.size() > input_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the size of 'begin' cannot be greater than the dimension of input, but got the "
|
||||
|
@ -118,7 +107,6 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
|
|||
output_size_ = 1;
|
||||
update_size_ = 1;
|
||||
is_null_input_ = false;
|
||||
is_dynamic_attr_ = false;
|
||||
input_shape_.clear();
|
||||
output_shape_.clear();
|
||||
update_shape_.clear();
|
||||
|
@ -224,7 +212,7 @@ class TensorCopySlicesGpuKernelMod : public NativeGpuKernelMod {
|
|||
size_t output_size_;
|
||||
inline static size_t kMaxDims = 8;
|
||||
bool is_null_input_;
|
||||
bool is_dynamic_attr_{false};
|
||||
std::string kernel_name_;
|
||||
static constexpr size_t kBeginIndex_{2};
|
||||
static constexpr size_t kEndIndex_{3};
|
||||
static constexpr size_t kStrideIndex_{4};
|
||||
|
|
|
@ -22,8 +22,7 @@
namespace mindspore {
namespace kernel {
namespace {
-constexpr size_t kStaticInputNum = 1;
-constexpr size_t kDynInputNum = 2;
+constexpr size_t kInputNum = 2;
constexpr size_t kTileOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;

@ -35,12 +34,8 @@ bool TileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = base_operator->name();
size_t input_num = inputs.size();
-if (input_num == kStaticInputNum) {
-is_dynamic_case_ = false;
-} else if (input_num == kDynInputNum) {
-is_dynamic_case_ = true;
-} else {
-MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 1 or 2, but got " << input_num;
+if (input_num != kInputNum) {
+MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 2, but got " << input_num;
return false;
}

@ -83,16 +78,7 @@ int TileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::ve
}
shape_size_ = output_shape_.size();
output_size_ = SizeOf(output_shape_);
-if (!is_dynamic_case_) {
-const std::string kAttrMultiples = "multiples";
-auto multi_attr = base_operator->GetPrim()->GetAttr(kAttrMultiples);
-if (multi_attr == nullptr) {
-return KRET_RESIZE_FAILED;
-}
-multiples = GetValue<std::vector<int64_t>>(multi_attr);
-} else {
TryGetIntValue(inputs, kIndex1, kernel_name_, &multiples);
-}
int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
(void)input_shape_.insert(input_shape_.begin(), filling_value, kIndex1);
workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));

@ -127,24 +113,6 @@ template <typename T>
using Complex = mindspore::utils::Complex<T>;

std::vector<std::pair<KernelAttr, TileGpuKernelMod::TileLaunchFunc>> TileGpuKernelMod::func_list_ = {
-{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
-&TileGpuKernelMod::LaunchKernel<Complex<float>>},
-{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
-&TileGpuKernelMod::LaunchKernel<Complex<double>>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-&TileGpuKernelMod::LaunchKernel<double>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-&TileGpuKernelMod::LaunchKernel<float>},
-{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-&TileGpuKernelMod::LaunchKernel<half>},
-{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
-&TileGpuKernelMod::LaunchKernel<int16_t>},
-{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), &TileGpuKernelMod::LaunchKernel<int>},
-{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-&TileGpuKernelMod::LaunchKernel<int64_t>},
-{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), &TileGpuKernelMod::LaunchKernel<int>},
-{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), &TileGpuKernelMod::LaunchKernel<bool>},
-// For dynamic shape case:
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
&TileGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),

@ -50,7 +50,6 @@ class TileGpuKernelMod : public NativeGpuKernelMod {
output_size_ = 1;
shape_size_ = 1;
is_null_input_ = false;
-is_dynamic_case_ = false;
input_shape_.clear();
output_shape_.clear();
input_size_list_.clear();

@ -68,7 +67,6 @@ class TileGpuKernelMod : public NativeGpuKernelMod {
ShapeVector input_shape_;
ShapeVector output_shape_;
bool is_null_input_;
-bool is_dynamic_case_;
std::vector<int64_t> multiples;
using TileLaunchFunc =
std::function<bool(TileGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
@ -24,7 +24,7 @@ namespace kernel {
|
|||
template <typename T>
|
||||
using Complex = mindspore::utils::Complex<T>;
|
||||
|
||||
constexpr size_t kDynamicPermInputNum = 2;
|
||||
constexpr size_t kPermInputNum = 2;
|
||||
constexpr size_t kDimSize4 = 4;
|
||||
constexpr size_t kAxisZero = 0;
|
||||
constexpr size_t kAxis1st = 1;
|
||||
|
@ -35,10 +35,7 @@ constexpr size_t kAxisIndex1st = 1;
|
|||
constexpr size_t kAxisIndex2nd = 2;
|
||||
constexpr size_t kAxisIndex3rd = 3;
|
||||
|
||||
#define STATIC_REGISTER(INPUTX, OUTPUT, T) \
|
||||
{ KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(OUTPUT), &TransposeGpuKernelMod::LaunchKernel<T> }
|
||||
|
||||
#define DYN_REGISTER(INPUTX, PERM, OUTPUT, T) \
|
||||
#define OP_REGISTER(INPUTX, PERM, OUTPUT, T) \
|
||||
{ \
|
||||
KernelAttr().AddInputAttr(INPUTX).AddInputAttr(PERM).AddOutputAttr(OUTPUT), \
|
||||
&TransposeGpuKernelMod::LaunchKernel<T> \
|
||||
|
@ -47,48 +44,34 @@ constexpr size_t kAxisIndex3rd = 3;
|
|||
const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> &TransposeGpuKernelMod::GetFuncList()
|
||||
const {
|
||||
static const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> func_list = {
|
||||
STATIC_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
|
||||
STATIC_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
|
||||
STATIC_REGISTER(kNumberTypeBool, kNumberTypeBool, bool),
|
||||
STATIC_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double),
|
||||
STATIC_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float),
|
||||
STATIC_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half),
|
||||
STATIC_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t),
|
||||
STATIC_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t),
|
||||
STATIC_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t),
|
||||
STATIC_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t),
|
||||
    STATIC_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t),
    STATIC_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t),
    STATIC_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t),
    STATIC_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t),
    DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
    DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
    DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
    DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
    DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
    DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
    DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
    DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
    DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
    DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
    DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
    DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
    DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
    DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
    DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
    DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
    DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
    DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
    DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
    DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
    DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
    DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
    DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
    DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
    DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
    DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
    DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
    DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
    OP_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
    OP_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
    OP_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
    OP_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
    OP_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
    OP_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
    OP_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
    OP_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
    OP_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
    OP_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
    OP_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
    OP_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
    OP_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
    OP_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
    OP_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
    OP_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
    OP_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
    OP_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
    OP_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
    OP_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
    OP_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
    OP_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
    OP_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
    OP_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
    OP_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
    OP_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
    OP_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
    OP_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
  };
  return func_list;
}

@@ -102,10 +85,6 @@ bool TransposeGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
  size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
  size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);

  if (is_dynamic_perm_ && !get_dynamic_perm_value_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic perm!";
  }

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
@@ -155,26 +134,13 @@ bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
  size_t input_num = inputs.size();
  size_t output_num = outputs.size();
  kernel_name_ = base_operator->name();
  if (input_num != 1 && input_num != kDynamicPermInputNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
                      << ", but got " << input_num;
  if (input_num != kPermInputNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << kPermInputNum << ", but got "
                      << input_num;
  }
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
  }
  if (input_num == kDynamicPermInputNum) {
    is_dynamic_perm_ = true;
    return true;
  }

  auto kernel_ptr = std::dynamic_pointer_cast<ops::Transpose>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "For primitive[Transpose], cast op from BaseOperator to Transpose failed.";
    return false;
  }
  std::vector<int64_t> perm;
  perm = kernel_ptr->get_perm();
  GetPermValue(perm);
  return true;
}

@@ -120,16 +120,9 @@ std::vector<KernelAttr> MemcpyGpuKernelMod::GetOpSupport() {
                      .AddInputAttr(kNumberTypeInt32)
                      .AddOutputAttr(kNumberTypeComplex128)};

  std::vector<KernelAttr> reshape_valid_types;
  reshape_valid_types.insert(reshape_valid_types.end(), common_valid_types_with_single_input.begin(),
                             common_valid_types_with_single_input.end());
  reshape_valid_types.insert(reshape_valid_types.end(), common_valid_types_with_double_input.begin(),
                             common_valid_types_with_double_input.end());
  static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
    {kReshape, reshape_valid_types},
    {kFlatten, common_valid_types_with_single_input},
    {kFlattenGrad, common_valid_types_with_single_input},
    {kExpandDims, common_valid_types_with_double_input},
    {kReshape, common_valid_types_with_double_input}, {kFlatten, common_valid_types_with_single_input},
    {kFlattenGrad, common_valid_types_with_single_input}, {kExpandDims, common_valid_types_with_double_input},
    {kSqueeze, common_valid_types_with_single_input},
  };

@@ -37,6 +37,8 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
  (void)node_inputs.insert(node_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
  CNodePtr new_cnode = func_graph->NewCNode(node_inputs);
  MS_EXCEPTION_IF_NULL(new_cnode);
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto predict_input = cnode->inputs()[1];
  auto new_node_dtype = {common::AnfAlgo::GetOutputInferDataType(predict_input, 0)};
  auto new_node_shape = {common::AnfAlgo::GetOutputDetailShape(predict_input, 0)};
@@ -46,10 +48,20 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
  string reduction = common::AnfAlgo::GetNodeAttr<std::string>(node, kAttrReduction);
  MS_LOG(INFO) << "Create reduce node for BCEWithLogitsLoss, reduction attr is: " << reduction;
  std::vector<AnfNodePtr> reduce_inputs;
  std::vector<int64_t> axis_shp = {0};
  auto axis_tensor = std::make_shared<tensor::Tensor>(kInt64->type_id(), axis_shp);
  MS_EXCEPTION_IF_NULL(axis_tensor);
  tensor::DeviceInfo device_info{kOpFormat_DEFAULT, kInt64};
  axis_tensor->set_device_info(device_info);
  ValueNodePtr axis_node = std::make_shared<ValueNode>(axis_tensor);
  MS_EXCEPTION_IF_NULL(axis_node);
  axis_node->set_abstract(axis_tensor->ToAbstract());
  axis_node = kernel_graph->NewValueNode(axis_node);
  kernel_graph->AddValueNode(axis_node);
  if (reduction == "sum") {
    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode};
    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceSum->name())), new_cnode, axis_node};
  } else if (reduction == "mean") {
    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode};
    reduce_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimReduceMean->name())), new_cnode, axis_node};
  } else {
    MS_LOG(INFO) << "Reduction is none, no optimization on current BCEWithLogitsLoss.";
    return nullptr;
@@ -59,7 +71,6 @@ AnfNodePtr AddReduceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node)
  auto type = common::AnfAlgo::GetOutputInferDataType(node, 0);
  auto shape = {common::AnfAlgo::GetOutputDetailShape(node, 0)};
  common::AnfAlgo::SetOutputTypeAndDetailShape({type}, shape, reduce_node.get());
  common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(std::vector<int64_t>{}), reduce_node);
  common::AnfAlgo::SetNodeAttr("keep_dims", MakeValue(false), reduce_node);
  reduce_node->set_scope(cnode->scope());
  return reduce_node;

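Note: since ReduceSum/ReduceMean now consume axis as a real input, the pass above materializes an empty int64 tensor (shape {0}) as a value node; with skip_mode left at its default of False, an empty axis still means "reduce every dimension". A hedged numpy sketch of that convention (illustrative only, not part of the patch):

    import numpy as np

    x = np.arange(6.0).reshape(2, 3)
    # The empty axis tensor created above stands for "all dimensions":
    full_reduce = np.sum(x, axis=tuple(range(x.ndim)))
    assert full_reduce == np.sum(x)  # 15.0, a 0-D result
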
@@ -31,6 +31,23 @@ namespace opt {
namespace {
const int kFakeTransposeShapeOneNum = 2;

AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto value_node = input_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  tensor_input = kernel_graph->NewValueNode(tensor_input);
  kernel_graph->AddValueNodeToGraph(tensor_input);
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}

std::vector<int64_t> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
  if ((src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC)) {
    return {0, 2, 3, 1};
@@ -66,8 +83,8 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string
  auto input_type = common::AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  auto output_type = common::AnfAlgo::GetOutputInferDataType(node, 0);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetInputsDeviceType({input_type});
  builder.SetInputsFormat({input_format, kOpFormat_DEFAULT});
  builder.SetInputsDeviceType({input_type, kNumberTypeInt64});
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({output_type});
  builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
@@ -81,6 +98,8 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(used_node);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
                << ", index: " << used_node_index;
@@ -97,18 +116,23 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co
  MS_EXCEPTION_IF_NULL(transpose_prim);
  // 2.Set the input of transpose.
  std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
  if (is_fake) {
    auto reshape_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index);
    auto shape_node = NewValueNode(MakeValue<std::vector<int64_t>>(reshape_shape));
    auto shape_tensor = ConvertValueToTensor(kernel_graph, shape_node);
    transpose_input.push_back(shape_tensor);
  } else {
    auto perm_node = NewValueNode(MakeValue<std::vector<int64_t>>(transpose_perm));
    auto perm_tensor = ConvertValueToTensor(kernel_graph, perm_node);
    transpose_input.push_back(perm_tensor);
  }
  auto transpose_op = graph->NewCNode(transpose_input);
  MS_EXCEPTION_IF_NULL(transpose_op);
  // 3.Set the output info of transpose.
  auto transpose_type = {common::AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
  auto transpose_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index);
  auto base_shape = common::AnfAlgo::GetPrevNodeOutputDetailShape(used_node, used_node_index);
  common::AnfAlgo::SetOutputTypeAndDetailShape(transpose_type, {base_shape}, transpose_op.get());
  if (is_fake) {
    common::AnfAlgo::SetNodeAttr("shape", MakeValue(transpose_shape), transpose_op);
  } else {
    common::AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
  }

  // 4. Set the new edge of transpose op.
  FuncGraphManagerPtr manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

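Note: ConvertValueToTensor above is the glue that lets these passes feed a constant perm or shape tuple to Transpose as a real tensor input rather than an attribute. A rough Python analogue of what the helper produces, assuming numpy (the function name here is illustrative, not part of the patch):

    import numpy as np

    def convert_value_to_tensor(value):
        # A constant tuple such as a transpose perm becomes a 1-D int64
        # tensor that the kernel can consume as an ordinary input.
        return np.asarray(value, dtype=np.int64)

    perm = convert_value_to_tensor((0, 2, 3, 1))
    assert perm.dtype == np.int64 and perm.shape == (4,)
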
@@ -22,6 +22,7 @@

#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/convert_utils.h"
#include "ir/primitive.h"
#include "include/common/utils/utils.h"
#include "plugin/device/gpu/hal/device/kernel_info_setter.h"
@@ -124,6 +125,23 @@ void CalSplitAttrs(SplitvNodeInfo *splitv_node_info) {
  splitv_node_info->num_split = num_split;
}

AnfNodePtr ConvertValueToTensor(const KernelGraphPtr &kernel_graph, const ValueNodePtr &input_node) {
  MS_EXCEPTION_IF_NULL(input_node);
  auto value_node = input_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  tensor::TensorPtr tensor_ptr = CreateTupleTensor(value->cast<ValueTuplePtr>());
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  auto tensor_input = std::make_shared<ValueNode>(tensor_ptr);
  MS_EXCEPTION_IF_NULL(tensor_input);
  tensor_input->set_abstract(tensor_ptr->ToAbstract());
  tensor_input = kernel_graph->NewValueNode(tensor_input);
  kernel_graph->AddValueNodeToGraph(tensor_input);
  tensor_input->set_scope(input_node->scope());
  return tensor_input;
}

CNodePtr CreateSliceNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &slice_input,
                         const SliceNodeInfo slice_node_info, const PatternProcessPass &pass) {
  MS_EXCEPTION_IF_NULL(graph);
@@ -131,20 +149,13 @@ CNodePtr CreateSliceNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr
    MS_LOG(EXCEPTION) << "The input is empty, can not create slice node.";
    return nullptr;
  }

  auto slice = pass.NewCNode(slice_input, graph);
  MS_EXCEPTION_IF_NULL(slice);
  ShapeVector slice_shape(slice_node_info.base_shape);
  std::vector<int64_t> begins(slice_shape.size(), 0);

  slice_shape[slice_node_info.slice_dim] = slice_node_info.slice_size;
  begins[slice_node_info.slice_dim] = slice_node_info.slice_begin;

  std::vector<ShapeVector> shapes = {slice_shape};
  std::vector<TypeId> dtypes(1, slice_node_info.input_dtype);
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, slice.get());
  common::AnfAlgo::SetNodeAttr(kAttrBegin, MakeValue(begins), slice);
  common::AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(slice_shape), slice);
  return slice;
}

@@ -152,6 +163,8 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const AnfNodePtr &split_inpu
                         const PatternProcessPass &pass) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(splitv_node_info);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  CalSplitAttrs(splitv_node_info);

@@ -168,7 +181,18 @@ CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const AnfNodePtr &split_inpu

  for (int64_t idx = 0; idx < splitv_node_info->num_split; ++idx) {
    slice_node_info.slice_size = splitv_node_info->size_splits[idx];
    std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), split_input};
    ShapeVector slice_shape(slice_node_info.base_shape);
    slice_shape[slice_node_info.slice_dim] = slice_node_info.slice_size;
    // begins
    std::vector<int64_t> begins(slice_shape.size(), 0);
    begins[slice_node_info.slice_dim] = slice_node_info.slice_begin;
    auto beigns_node = NewValueNode(MakeValue<std::vector<int64_t>>(begins));
    auto beigns_tensor = ConvertValueToTensor(kernel_graph, beigns_node);
    // slice_shape
    auto slice_shape_node = NewValueNode(MakeValue<std::vector<int64_t>>(slice_shape));
    auto slice_shape_tensor = ConvertValueToTensor(kernel_graph, slice_shape_node);
    std::vector<AnfNodePtr> slice_inputs = {NewValueNode(std::make_shared<Primitive>(kSliceOpName)), split_input,
                                            beigns_tensor, slice_shape_tensor};
    slice = CreateSliceNode(graph, slice_inputs, slice_node_info, pass);
    MS_EXCEPTION_IF_NULL(slice);
    AddNewOutputs(graph, slice, 1, &make_tuple_inputs);

@@ -24,16 +24,6 @@ namespace mindspore::opt {

RER_GPU_DYNAMIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kReshapeOpName, 1);
RER_GPU_DYNAMIC_CONST_TO_ATTR(kTransposeOpName, 1);

RER_GPU_STATIC_CONST_TO_ATTR(kCastOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kConv2DBackpropFilterOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kConv2DBackpropInputOpName, 2);
@@ -49,24 +39,11 @@ RER_GPU_STATIC_CONST_TO_ATTR(kCSRReduceSumOpName, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kCumprodOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kFillOpName, 0);
RER_GPU_STATIC_CONST_TO_ATTR(kMeanGradOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kOneHotOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kParallelResizeBilinearGradOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceAllOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceAnyOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMaxOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMeanOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceMinOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceProdOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReduceSumOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kReshapeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSpaceToBatchOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSparseApplyAdagradOpName, 2);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceAssignOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kStridedSliceOpName, 1, 2, 3);
RER_GPU_STATIC_CONST_TO_ATTR(kTensorCopySlicesOpName, 2, 3, 4);
RER_GPU_STATIC_CONST_TO_ATTR(kTileOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kTransposeOpName, 1);
RER_GPU_STATIC_CONST_TO_ATTR(kSliceOpName, 1, 2);
}  // namespace mindspore::opt

#endif  // MINDSPORE_CCSRC_PLUGIN_GPU_OPTIMIZER_REG_GPU_CONST_INPUT_TO_ATTR_H_

@@ -26,9 +26,13 @@ namespace mindspore {
namespace opt {
const BaseRef RemoveFormatTransformPair::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr Y = std::make_shared<Var>();
  VarPtr Z = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X);
  VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X});
  VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1});
  MS_EXCEPTION_IF_NULL(Y);
  MS_EXCEPTION_IF_NULL(Z);
  VectorRef transpose1 = VectorRef({prim::kPrimTranspose, X, Y});
  VectorRef transpose2 = VectorRef({prim::kPrimTranspose, transpose1, Z});
  return transpose2;
}

@@ -27,8 +27,10 @@ namespace mindspore {
namespace opt {
const BaseRef RemoveRedundantFormatTransform::DefinePattern() const {
  VarPtr X = std::make_shared<Var>();
  VarPtr Y = std::make_shared<Var>();
  MS_EXCEPTION_IF_NULL(X);
  VectorRef transpose = VectorRef({prim::kPrimTranspose, X});
  MS_EXCEPTION_IF_NULL(Y);
  VectorRef transpose = VectorRef({prim::kPrimTranspose, X, Y});
  return transpose;
}

@@ -49,8 +51,9 @@ const AnfNodePtr RemoveRedundantFormatTransform::Process(const FuncGraphPtr &gra
      break;
    }
  }
  auto first_transpose_perm = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(first_transpose, "perm");
  auto node_perm = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "perm");
  const int64_t perm_param_idx = 1;
  auto first_transpose_perm = AnfAlgo::GetInputDeviceShape(first_transpose, perm_param_idx);
  auto node_perm = AnfAlgo::GetInputDeviceShape(node, perm_param_idx);
  if ((first_transpose != node) && (first_transpose_perm == node_perm)) {
    return first_transpose;
  }

@@ -50,7 +50,7 @@ OperatorRegister &OperatorRegister::GetInstance() {
  return instance;
}

std::map<std::string, OperatorDefineFunc> OperatorRegister::GetOperatorMap() { return operator_fns_; }
const std::map<std::string, OperatorDefineFunc> &OperatorRegister::GetOperatorMap() { return operator_fns_; }
void OperatorRegister::SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn) {
  operator_fns_[kname] = fn;
}

@@ -63,7 +63,7 @@ class MIND_API OperatorRegister {

  static OperatorRegister &GetInstance();

  std::map<std::string, OperatorDefineFunc> GetOperatorMap();
  const std::map<std::string, OperatorDefineFunc> &GetOperatorMap();

  void SetOperatorMap(const std::string &kname, const OperatorDefineFunc &fn);

@@ -116,6 +116,7 @@ constexpr auto kIoFormat = "io_format";
constexpr auto kIsScale = "is_scale";
constexpr auto kIsTraining = "is_training";
constexpr auto kKeepDims = "keep_dims";
constexpr auto kSkipMode = "skip_mode";
constexpr auto kKeepProb = "keep_prob";
constexpr auto kPadDimSize = "pad_dim_size";
constexpr auto kUnbiased = "unbiased";

@@ -174,11 +174,17 @@ bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_ar
             input_args[kInputIndex1]->isa<abstract::AbstractTuple>() ||
             input_args[kInputIndex1]->isa<abstract::AbstractList>()) {
    *axis_value = CheckAndConvertUtils::CheckIntOrTupleInt("axis", input_value, op_name);
    if (axis_value->empty()) {
      *axis_shape_v = 0;
    }
  } else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
    (void)CheckAndConvertUtils::CheckTensorTypeValid("axis", input_args[kInputIndex1]->BuildType(), {kInt32, kInt64},
                                                     op_name);
    if (input_value->isa<tensor::Tensor>()) {
      *axis_value = CheckAndConvertUtils::CheckTensorIntValue("axis", input_value, op_name);
      if (axis_value->empty()) {
        *axis_shape_v = 0;
      }
    } else {
      is_dynamic = true;
      auto axis_shape = CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 1);
@@ -199,6 +205,14 @@ bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_ar
  return is_dynamic;
}

bool IsDynamicShapeSkipExecute(const bool skip_mode, const ShapeVector &axes_shape) {
  // Skip run ReduceSum when axis is a Empty Tensor
  if (std::any_of(axes_shape.begin(), axes_shape.end(), [](int64_t shape) { return shape == 0; }) && skip_mode) {
    return true;
  }
  return false;
}

abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args,
                                        const std::string &prim_name) {
@@ -206,6 +220,12 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
  auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
  MS_EXCEPTION_IF_NULL(shape_ptr);
  auto x_shape = shape_ptr->shape();
  bool skip_mode = false;
  if (primitive->HasAttr(kSkipMode)) {
    auto skip_mode_value_ptr = primitive->GetAttr(kSkipMode);
    MS_EXCEPTION_IF_NULL(skip_mode_value_ptr);
    skip_mode = GetValue<bool>(skip_mode_value_ptr);
  }
  auto keep_dimis_value_ptr = primitive->GetAttr(kKeepDims);
  MS_EXCEPTION_IF_NULL(keep_dimis_value_ptr);
  if (!keep_dimis_value_ptr->isa<BoolImm>()) {
@@ -213,8 +233,11 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
  }
  bool keep_dims = GetValue<bool>(keep_dimis_value_ptr);
  std::vector<int64_t> axis_value;
  int64_t axis_shape = 0;
  int64_t axis_shape = 1;
  bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
  if (IsDynamicShapeSkipExecute(skip_mode, {axis_shape})) {
    return std::make_shared<abstract::Shape>(x_shape);
  }
  ShapeVector out_shape = {};
  constexpr int dynamic_rank_value = -2;
  if (IsDynamicRank(x_shape)) {

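Note: IsDynamicShapeSkipExecute above encodes the skip_mode contract used throughout this patch: the reduce is skipped only when skip_mode is set and the axis tensor is empty (some dimension of its shape is 0). A minimal Python sketch of the same predicate (illustrative only):

    def is_dynamic_shape_skip_execute(skip_mode, axes_shape):
        # Skip the reduce only when skip_mode is on and the axis tensor is empty.
        return skip_mode and any(dim == 0 for dim in axes_shape)

    assert is_dynamic_shape_skip_execute(True, [0])       # empty axis -> skip
    assert not is_dynamic_shape_skip_execute(False, [0])  # skip_mode off -> reduce
    assert not is_dynamic_shape_skip_execute(True, [2])   # two axes -> reduce
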
@@ -28,23 +28,13 @@ void Reduce::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims

bool Reduce::get_keep_dims() const { return GetValue<bool>(GetAttr(kKeepDims)); }

void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); }
void Reduce::set_skip_mode(const bool skip_mode) { (void)this->AddAttr(kSkipMode, api::MakeValue(skip_mode)); }

void Reduce::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
bool Reduce::get_skip_mode() const { return GetValue<bool>(GetAttr(kSkipMode)); }

std::vector<int64_t> Reduce::get_axis() {
  PrimitivePtr prim = this->GetPrim();
  MS_EXCEPTION_IF_NULL(prim);
  std::vector<int64_t> axis = {};
  if (prim->HasAttr(kAttrAxis)) {
    auto value_ptr = prim->GetAttr(kAttrAxis);
    if (value_ptr->isa<tensor::Tensor>()) {
      axis = CheckAndConvertUtils::CheckTensorIntValue(kAttrAxis, value_ptr, prim->name());
    } else {
      axis = CheckAndConvertUtils::CheckIntOrTupleInt(kAttrAxis, value_ptr, prim->name());
    }
  }
  return axis;
void Reduce::Init(const bool keep_dims, const bool skip_mode) {
  this->set_keep_dims(keep_dims);
  this->set_skip_mode(skip_mode);
}

MIND_API_OPERATOR_IMPL(Reduce, BaseOperator);

@@ -39,7 +39,8 @@ class MIND_API Reduce : public BaseOperator {
  /// \brief Method to init the op's attributes.
  ///
  /// \param[in] keep_dims Define whether keep the dims reduced, default false.
  void Init(const bool keep_dims = false);
  /// \param[in] skip_mode Define whether skip reduce, default false.
  void Init(const bool keep_dims = false, const bool skip_mode = false);

  /// \brief Method to set keep_dims attribute.
  ///
@@ -51,9 +52,15 @@ class MIND_API Reduce : public BaseOperator {
  /// \return keep_dims attribute.
  bool get_keep_dims() const;

  void set_axis(const std::vector<int64_t> &axis);
  /// \brief Method to set skip_mode attribute.
  ///
  /// \param[in] skip_mode Define whether skip reduce, default false.
  void set_skip_mode(const bool skip_mode);

  std::vector<int64_t> get_axis();
  /// \brief Method to get skip_mode attribute.
  ///
  /// \return skip_mode attribute.
  bool get_skip_mode() const;
};
abstract::AbstractBasePtr ReduceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const std::vector<abstract::AbstractBasePtr> &input_args);

@@ -28,6 +28,7 @@
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kSliceInputNum = 3;
constexpr int64_t kDynamicOutValue = -2;
std::vector<int64_t> InferImplSliceFuncCalInputValue(const PrimitivePtr &primitive, const ValuePtr &input_value) {
  std::vector<int64_t> tmp_input;
@@ -87,6 +88,7 @@ ShapeVector GetOutputShape(const ShapeVector &input_size_shape, const ShapeVecto
abstract::ShapePtr SliceInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  MS_EXCEPTION_IF_CHECK_FAIL(input_args.size() == kSliceInputNum, "Slice inputs num error");
  auto input_x_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
  auto input_begin_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape());
  auto input_size_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape());

@@ -1,5 +1,5 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -447,6 +447,7 @@ int64_t StridedSlice::get_shrink_axis_mask() const {
  auto value_ptr = GetAttr(kShrinkAxisMask);
  return GetValue<int64_t>(value_ptr);
}

std::vector<int64_t> StridedSlice::get_begin() const {
  auto value_ptr = GetAttr(kAttrBegin);
  return GetValue<std::vector<int64_t>>(value_ptr);
@@ -459,6 +460,7 @@ std::vector<int64_t> StridedSlice::get_strides() const {
  auto value_ptr = GetAttr(kAttrStrides);
  return GetValue<std::vector<int64_t>>(value_ptr);
}

void StridedSlice::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
                        int64_t shrink_axis_mask) {
  this->set_begin_mask(begin_mask);

@@ -29,8 +29,6 @@

namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);

std::vector<int64_t> Transpose::get_perm() {
  PrimitivePtr prim = this->GetPrim();
  MS_EXCEPTION_IF_NULL(prim);
@@ -46,46 +44,31 @@ std::vector<int64_t> Transpose::get_perm() {
  return perm;
}

bool CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, ShapeVector *perm_value,
                          const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(perm_value);
  bool is_dynamic = false;
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);

ShapeVector CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, const PrimitivePtr &primitive) {
  const std::string &op_name = primitive->name();

  if (input_args.size() == 1) {
    if (!primitive->HasAttr("perm")) {
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of 'input_perm' is necessary, but missing it.";
    }
    ValuePtr perm = primitive->GetAttr("perm");
    MS_EXCEPTION_IF_NULL(perm);
    if (perm->isa<tensor::Tensor>()) {
      *perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", perm, op_name);
    } else {
      *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", perm, op_name);
    }
    return is_dynamic;
  }

  const size_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto input_value = input_args[kInputIndex1]->BuildValue();
  if (input_args[kInputIndex1]->isa<abstract::AbstractTuple>()) {
    *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
    return CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
  } else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
    if (input_value->isa<tensor::Tensor>()) {
      *perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
    } else {
      is_dynamic = true;
      return CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
    }
    auto perm_shape = CheckAndConvertUtils::GetTensorInputShape("perm", input_args, 1);
    if (perm_shape->shape().size() != 1) {
      MS_EXCEPTION(ValueError) << "For 'transpose perm', " << op_name << " must be 1-D, but got"
                               << perm_shape->shape().size() << "-D.";
    }
  } else {
    MS_EXCEPTION(TypeError) << "For primitive[" << op_name
                            << "], the perm must be a tuple or a tensor with all Int elements, but got "
                            << input_args[kInputIndex1]->type_name() << ".";
  }
  return is_dynamic;

  return {};
}

class TransposeInfer : public abstract::OpInferBase {

@@ -97,7 +80,6 @@ class TransposeInfer : public abstract::OpInferBase {
    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
    (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
    ShapeVector p_value;
    ShapeVector p_value_raw;
    if (x_shape[0] == 0) {
      MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
    }
@@ -105,10 +87,10 @@ class TransposeInfer : public abstract::OpInferBase {
      return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
    }

    bool perm_is_dynamic = CheckAndGetPermValue(input_args, &p_value_raw, primitive);
    if (perm_is_dynamic) {
    auto p_value_raw = CheckAndGetPermValue(input_args, primitive);
    if (p_value_raw.empty()) {
      ShapeVector out_shape;
      (void)out_shape.insert(out_shape.end(), x_shape.size(), -1);
      (void)out_shape.insert(out_shape.end(), x_shape.size(), abstract::Shape::kShapeDimAny);
      return std::make_shared<abstract::Shape>(out_shape);
    }

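Note: after this change the C++ infer path reads perm either from the legacy attr (single-input form) or from the tuple/tensor second input, matching how the op is already called from Python. A usage sketch, assuming a MindSpore build that contains this patch:

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    x = ms.Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
    transpose = ops.Transpose()
    y = transpose(x, (1, 0))  # perm travels as a real input, not an attribute
    assert y.shape == (3, 2)
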
@@ -223,16 +223,15 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
  auto expander = std::make_shared<DefaultExpander>(Callback::Instance());
  ExpanderCreatorFuncList decos = {InferValueDeco::Creator};
  std::map<std::string, ExpanderCreatorFuncList> creators = {
    {prim::kPrimReduceFusion->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
    {prim::kPrimExpandDims->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
    {prim::kPrimReduceFusion->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimExpandDims->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimUnsqueeze->name(), {FixFormatDeco::Creator}},
    {prim::kPrimSqueeze->name(), {FixFormatDeco::Creator}},
    {prim::kPrimShape->name(), {FixFormatDeco::Creator}},
    {prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
    {prim::kPrimConstantOfShape->name(), {InputToAttrDeco::GetCreator({0}), FixFormatDeco::Creator}},
    {prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
    {prim::kPrimGather->name(),
     {TensorToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}), FixFormatDeco::Creator}},
    {prim::kPrimReshape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimConstantOfShape->name(), {InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::Creator}},
    {prim::kPrimGather->name(), {TensorToValueDeco::GetCreator({2}), InputToAttrDeco::Creator, FixFormatDeco::Creator}},
    {prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
    {prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
    {prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},

@@ -149,5 +149,5 @@ def dyn_ones(shape, value_type):


def sum_grad_reduce_axis(x, rx, keep_dims=False):
    """sum the grad_reduce_axis"""
    return P.ReduceSum(keep_dims=keep_dims)(x, rx)
    """sum the grad_reduce_axis. when rx is an empty tensor and skip_mode is True, we need skip this reduce"""
    return P.ReduceSum(keep_dims=keep_dims, skip_mode=True)(x, rx)

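Note: skip_mode=True matters here because under dynamic shape the reduce axes rx are computed at runtime and can come back as an empty tensor; the gradient must then pass x through unchanged rather than summing over all axes. A hedged numpy sketch of the intended semantics (the helper name is illustrative):

    import numpy as np

    def sum_grad_reduce_axis_ref(x, reduce_axis, keep_dims=False, skip_mode=True):
        # Reference behaviour: an empty axis list with skip_mode on is a no-op.
        if skip_mode and len(reduce_axis) == 0:
            return x
        return np.sum(x, axis=tuple(reduce_axis), keepdims=keep_dims)

    x = np.ones((2, 3))
    assert sum_grad_reduce_axis_ref(x, []) is x            # skipped
    assert sum_grad_reduce_axis_ref(x, [0]).shape == (3,)  # normal reduce
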
@@ -513,7 +513,6 @@ class _Reduce(PrimitiveWithCheck):
        value = None
        if input_x is not None and axis is not None:
            prim_map = {
                'ReduceSum': np.sum,
                'ReduceMax': np.max,
                'ReduceMin': np.min,
                'ReduceProd': np.prod,
@@ -737,7 +736,7 @@ class CumulativeLogsumexp(Primitive):
        validator.check_bool(reverse, "reverse", self.name)


class ReduceSum(_Reduce):
class ReduceSum(PrimitiveWithCheck):
    """
    Reduces a dimension of a tensor by summing all elements in the dimension, by default. And also can reduce a
    dimension of `x` along the axis. Determine whether the dimensions of the output and input are the same by
@@ -746,18 +745,24 @@ class ReduceSum(_Reduce):
    Args:
        keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
            If false, don't keep these dimensions. Default: False.
        skip_mode (bool): If true and axis is empty tuple or empty list, the ReduceSum operation isn't performed,
            skip it.
            If true and axis is other values, the ReduceSum calculation is performed normally.
            If false, do reduce. Default: False.

    Inputs:
        - **x** (Tensor[Number]) - The input tensor. The dtype of the tensor to be reduced is number.
          :math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
          Only constant value is allowed. Must be in the range [-rank(`x`), rank(`x`)).
        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions
          when skip_mode is false. Only constant value is allowed. Must be in the range [-rank(`x`), rank(`x`)).

    Outputs:
        Tensor, has the same dtype as the `x`.

        - If axis is (), and keep_dims is False,
        - If axis is (), keep_dims is False, and skip_mode is False,
          the output is a 0-D tensor representing the sum of all elements in the input tensor.
        - If axis is (), and skip_mode is True,
          the ReduceSum operation is not performed, output tensor is equal to the input tensor.
        - If axis is int, set as 2, and keep_dims is False,
          the shape of output is :math:`(x_1, x_3, ..., x_R)`.
        - If axis is tuple(int) or list(int), set as (2, 3), and keep_dims is False,
@@ -765,6 +770,7 @@ class ReduceSum(_Reduce):

    Raises:
        TypeError: If `keep_dims` is not a bool.
        TypeError: If `skip_mode` is not a bool.
        TypeError: If `x` is not a Tensor.
        ValueError: If `axis` is None.

@@ -812,12 +818,44 @@ class ReduceSum(_Reduce):
         [54.]]]
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x'),
        sig.make_sig('axis', default=())
    )

    @prim_attr_register
    def __init__(self, keep_dims=False):
        """Initialize ReduceSum"""
        super(ReduceSum, self).__init__(keep_dims)
    def __init__(self, keep_dims=False, skip_mode=False):
        """Initialize Reduce"""
        validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
        validator.check_value_type('skip_mode', skip_mode, [bool], self.name)
        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
        self.keep_dims = keep_dims
        self.skip_mode = skip_mode
        self.__setattr_flag__ = True

    def __call__(self, x, axis=()):
        args = [x, axis]
        output = _run_op(self, self.name, args)
        return output

    def infer_value(self, input_x, axis):
        """ return reduce op value"""
        value = None
        if input_x is not None and axis is not None:
            value = input_x.asnumpy()
            if isinstance(axis, int):
                pass
            elif axis:
                axis = tuple(set(axis))
            elif axis in ((), []) and self.skip_mode:
                return input_x
            else:
                axis = tuple(range(len(value.shape)))
            value = np.sum(value, axis, keepdims=self.keep_dims)
            value = np.array(value)
            value = Tensor(value)
        return value

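Note: infer_value above mirrors the kernel-side rule: with skip_mode=True and an empty axis the input is returned untouched, otherwise the sum proceeds as before. A usage sketch, assuming a MindSpore build that contains this change:

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    x = ms.Tensor(np.ones((2, 3), dtype=np.float32))
    reduce_sum = ops.ReduceSum(keep_dims=False, skip_mode=True)
    assert reduce_sum(x, ()).shape == (2, 3)    # empty axis + skip_mode: pass-through
    assert float(reduce_sum(x, (0, 1))) == 6.0  # concrete axis reduces as usual
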
class ReduceAll(_Reduce):
    """