forked from mindspore-Ecosystem/mindspore
!15539 clean pclint warning
From: @lianliguang Reviewed-by: @jjfeing,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
9b965888aa
|
@ -758,15 +758,15 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
|
|||
}
|
||||
std::vector<int64_t> axis_list;
|
||||
if (axis_attr->isa<Int64Imm>()) {
|
||||
axis_list.emplace_back(GetValue<int64_t>(axis_attr));
|
||||
(void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
|
||||
} else {
|
||||
axis_list = GetValue<std::vector<int64_t>>(axis_attr);
|
||||
}
|
||||
for (const auto &elem : axis_list) {
|
||||
if (elem < 0) {
|
||||
axis.emplace_back(input_shape.size() + elem);
|
||||
(void)axis.emplace_back(input_shape.size() + elem);
|
||||
} else {
|
||||
axis.emplace_back(elem);
|
||||
(void)axis.emplace_back(elem);
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
|
||||
|
|
|
@ -65,7 +65,7 @@ void HcclMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<K
|
|||
std::vector<TypeId> inputs_type{};
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
|
||||
(void)inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
|
||||
inputs_type.push_back(type);
|
||||
}
|
||||
std::vector<std::string> outputs_format;
|
||||
|
|
|
@ -209,8 +209,8 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::str
|
|||
if (index >= kernel_build_info_->input_reshape_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
|
||||
(void)std::copy(input_reshape_type.begin(), input_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
|
||||
|
@ -218,8 +218,8 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::st
|
|||
if (index >= kernel_build_info_->output_reshape_type_.size()) {
|
||||
MS_LOG(EXCEPTION) << "index outof range!";
|
||||
}
|
||||
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
|
||||
(void)std::copy(output_reshape_type.begin(), output_reshape_type.end(),
|
||||
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
|
||||
|
|
|
@ -73,7 +73,7 @@
|
|||
#include "backend/optimizer/ascend/format_type/insert_transpose_for_basiclstm_op.h"
|
||||
#include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h"
|
||||
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
|
||||
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
||||
#include "backend/optimizer/ascend/format_type/change_axis_of_reduce_kernel.h"
|
||||
#include "backend/optimizer/ascend/format_type/convert_cast_format.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/pass/optimize_dependence.h"
|
||||
|
|
|
@ -1,103 +1,103 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "utils/utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
using ConvertFunction = std::function<void(const CNodePtr &)>;
|
||||
|
||||
void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode);
|
||||
const size_t kAxis_H = 2;
|
||||
const size_t kAxis_W = 3;
|
||||
const size_t kAxis_6HD_H = 1;
|
||||
const size_t kAxis_6HD_W = 2;
|
||||
const std::map<std::string, ConvertFunction> kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD},
|
||||
{kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}};
|
||||
void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int64_t> &reduce_axis) {
|
||||
if (reduce_axis.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << cnode->DebugString() << "'s reduce axis got a empty vector";
|
||||
}
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
|
||||
AnfAlgo::GetInputTensorNum(cnode) != 1) {
|
||||
MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
|
||||
<< "] is not single input or single output ";
|
||||
}
|
||||
for (auto elem : reduce_axis) {
|
||||
if (elem > 4) {
|
||||
MS_LOG(INFO) << "reduce axis is larger than 4 dims reduce axis : [" << elem << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) {
|
||||
auto axis = kernel::GetReduceAttrAxis(cnode);
|
||||
std::vector<int64_t> convert_axis;
|
||||
SafeCheckFunction(cnode, axis);
|
||||
auto format = AnfAlgo::GetInputFormat(cnode, 0);
|
||||
if (format != kOpFormat_FRAC_Z && format != kOpFormat_C1HWNCoC0) {
|
||||
MS_LOG(EXCEPTION) << "The node [" << cnode->DebugString() << "] format " << format
|
||||
<< " is not needed to change the axis";
|
||||
}
|
||||
for (auto elem : axis) {
|
||||
switch (elem) {
|
||||
case kAxis_H:
|
||||
convert_axis.emplace_back(kAxis_6HD_H);
|
||||
break;
|
||||
case kAxis_W:
|
||||
convert_axis.emplace_back(kAxis_6HD_W);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(INFO) << "reduce axis is axis : [" << elem << "]"
|
||||
<< " but the format is not supported this reduce axis";
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(convert_axis), cnode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
|
||||
return nullptr;
|
||||
}
|
||||
auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
|
||||
if (convert_map == kReduceConvertMap.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
convert_map->second(node->cast<CNodePtr>());
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/format_type/change_axis_of_reduce_kernel.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "utils/utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
using ConvertFunction = std::function<void(const CNodePtr &)>;
|
||||
|
||||
void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode);
|
||||
const size_t kAxis_H = 2;
|
||||
const size_t kAxis_W = 3;
|
||||
const size_t kAxis_6HD_H = 1;
|
||||
const size_t kAxis_6HD_W = 2;
|
||||
const std::map<std::string, ConvertFunction> kReduceConvertMap = {{kOpFormat_FRAC_Z, ConvertReduceAttrFraczAnd6HD},
|
||||
{kOpFormat_C1HWNCoC0, ConvertReduceAttrFraczAnd6HD}};
|
||||
void SafeCheckFunction(const CNodePtr &cnode, const std::vector<int64_t> &reduce_axis) {
|
||||
if (reduce_axis.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << cnode->DebugString() << "'s reduce axis got a empty vector";
|
||||
}
|
||||
if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
|
||||
AnfAlgo::GetInputTensorNum(cnode) != 1) {
|
||||
MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
|
||||
<< "] is not single input or single output ";
|
||||
}
|
||||
for (auto elem : reduce_axis) {
|
||||
if (elem > 4) {
|
||||
MS_LOG(INFO) << "reduce axis is larger than 4 dims reduce axis : [" << elem << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ConvertReduceAttrFraczAnd6HD(const CNodePtr &cnode) {
|
||||
auto axis = kernel::GetReduceAttrAxis(cnode);
|
||||
std::vector<int64_t> convert_axis;
|
||||
SafeCheckFunction(cnode, axis);
|
||||
auto format = AnfAlgo::GetInputFormat(cnode, 0);
|
||||
if (format != kOpFormat_FRAC_Z && format != kOpFormat_C1HWNCoC0) {
|
||||
MS_LOG(EXCEPTION) << "The node [" << cnode->DebugString() << "] format " << format
|
||||
<< " is not needed to change the axis";
|
||||
}
|
||||
for (auto elem : axis) {
|
||||
switch (elem) {
|
||||
case kAxis_H:
|
||||
convert_axis.emplace_back(kAxis_6HD_H);
|
||||
break;
|
||||
case kAxis_W:
|
||||
convert_axis.emplace_back(kAxis_6HD_W);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(INFO) << "reduce axis is axis : [" << elem << "]"
|
||||
<< " but the format is not supported this reduce axis";
|
||||
}
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(convert_axis), cnode);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {
|
||||
return nullptr;
|
||||
}
|
||||
auto convert_map = kReduceConvertMap.find(AnfAlgo::GetInputFormat(node, 0));
|
||||
if (convert_map == kReduceConvertMap.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
convert_map->second(node->cast<CNodePtr>());
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,33 +1,33 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ChangeAxisOfReduceKernel : public PatternProcessPass {
|
||||
public:
|
||||
explicit ChangeAxisOfReduceKernel(bool multigraph = true)
|
||||
: PatternProcessPass("change_axis_of_reduce_kernel", multigraph) {}
|
||||
~ChangeAxisOfReduceKernel() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ChangeAxisOfReduceKernel : public PatternProcessPass {
|
||||
public:
|
||||
explicit ChangeAxisOfReduceKernel(bool multigraph = true)
|
||||
: PatternProcessPass("change_axis_of_reduce_kernel", multigraph) {}
|
||||
~ChangeAxisOfReduceKernel() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CHANGE_AXIS_OF_REDUCE_KENRNEL_H_
|
|
@ -13,11 +13,11 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H_
|
||||
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H_
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_helper.h"
|
||||
#ifndef MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
||||
#define MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvertUnSupportNodeToAICPU : public PatternProcessPass {
|
||||
|
@ -34,4 +34,4 @@ class ConvertUnSupportNodeToAICPU : public PatternProcessPass {
|
|||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H
|
||||
#endif // MINDSPORE_CONVERT_UNSUPPORTED_NODE_TO_AICPU_H_
|
||||
|
|
|
@ -58,7 +58,7 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
|
|||
auto new_node = kernel_graph->NewCNode(cnode);
|
||||
auto manager = kernel_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->Replace(cnode, new_node);
|
||||
(void)manager->Replace(cnode, new_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,12 +127,12 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
|
|||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t index = 0; index < output_num; ++index) {
|
||||
if (cast_index == index) {
|
||||
shapes.emplace_back(cast_shape);
|
||||
types.emplace_back(cast_dtype);
|
||||
(void)shapes.emplace_back(cast_shape);
|
||||
(void)types.emplace_back(cast_dtype);
|
||||
continue;
|
||||
}
|
||||
shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
|
||||
types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
|
||||
(void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
|
||||
(void)types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
|
||||
}
|
||||
|
|
|
@ -763,7 +763,7 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
|
|||
if (input_index >= input_abstract.size()) {
|
||||
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract " << input_abstract.size();
|
||||
}
|
||||
rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
|
||||
(void)rectifyed_abs_list.emplace_back(input_abstract[input_index++]);
|
||||
} else {
|
||||
if (item < 0) {
|
||||
MS_LOG(EXCEPTION) << " the dynamic input size check error the index should be -1 or positive number but got "
|
||||
|
@ -775,9 +775,9 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
|
|||
MS_LOG(EXCEPTION) << " index " << input_index << " is out of range in input abstract "
|
||||
<< input_abstract.size();
|
||||
}
|
||||
dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
|
||||
(void)dynamic_inputs_abs.emplace_back(input_abstract[input_index++]);
|
||||
}
|
||||
rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
|
||||
(void)rectifyed_abs_list.emplace_back(std::make_shared<abstract::AbstractTuple>(dynamic_inputs_abs));
|
||||
}
|
||||
}
|
||||
return rectifyed_abs_list;
|
||||
|
|
|
@ -44,7 +44,7 @@ int64_t SplitTupleInputs(const FuncGraphPtr &graph, const AnfNodePtr &tuple_inpu
|
|||
// using for graph kernel
|
||||
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
|
||||
MS_EXCEPTION_IF_NULL(dyn_input_node);
|
||||
plant_inputs->emplace_back(dyn_input_node);
|
||||
(void)plant_inputs->emplace_back(dyn_input_node);
|
||||
}
|
||||
return input_size;
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt
|
|||
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (AnfAlgo::IsTupleOutput(input_node)) {
|
||||
dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||
(void)dyn_input_sizes.emplace_back(SplitTupleInputs(graph, input_node, &plant_inputs));
|
||||
} else {
|
||||
dyn_input_sizes.push_back(-1);
|
||||
plant_inputs.push_back(input_node);
|
||||
|
|
|
@ -1498,23 +1498,23 @@ void AnfRuntimeAlgorithm::ReorderOptimizerExecList(NotNull<std::vector<CNodePtr>
|
|||
};
|
||||
|
||||
if (trans_pose_func(node)) {
|
||||
transpose_list.emplace_back(node);
|
||||
(void)transpose_list.emplace_back(node);
|
||||
} else if (trans_data_func(node)) {
|
||||
trans_list.emplace_back(node);
|
||||
(void)trans_list.emplace_back(node);
|
||||
} else if (cast_func(node)) {
|
||||
cast_list.emplace_back(node);
|
||||
(void)cast_list.emplace_back(node);
|
||||
} else if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
|
||||
all_opt_list.emplace_back(node);
|
||||
(void)all_opt_list.emplace_back(node);
|
||||
} else {
|
||||
non_opt_list.emplace_back(node);
|
||||
(void)non_opt_list.emplace_back(node);
|
||||
}
|
||||
}
|
||||
node_list->clear();
|
||||
std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
|
||||
std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||
std::copy(transpose_list.begin(), transpose_list.end(), std::back_inserter(*node_list));
|
||||
std::copy(trans_list.begin(), trans_list.end(), std::back_inserter(*node_list));
|
||||
std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(transpose_list.begin(), transpose_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(trans_list.begin(), trans_list.end(), std::back_inserter(*node_list));
|
||||
(void)std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list));
|
||||
}
|
||||
|
||||
void AnfRuntimeAlgorithm::ReorderPosteriorExecList(NotNull<std::vector<CNodePtr> *> node_list) {
|
||||
|
|
|
@ -158,7 +158,7 @@ bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map<std::string, std::strin
|
|||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
|
||||
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
if (value_node == nullptr) {
|
||||
return nullptr;
|
||||
|
@ -534,7 +534,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
|
|||
std::vector<std::string> formats = {kOpFormat_DEFAULT};
|
||||
if (node->isa<ValueNode>()) {
|
||||
kernel_info->set_feature_map_flag(false);
|
||||
types.emplace_back(kTypeUnknown);
|
||||
(void)types.emplace_back(kTypeUnknown);
|
||||
auto value_node = node->cast<ValueNodePtr>();
|
||||
SyncDeviceInfoToValueNode(value_node, &formats, &types);
|
||||
}
|
||||
|
|
|
@ -296,7 +296,7 @@ class KernelGraph : public FuncGraph {
|
|||
// remove value node form graph
|
||||
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
|
||||
void SetKernelInfoForNode(const AnfNodePtr &node) const;
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
|
||||
AnfNodePtr MakeValueNode(const AnfNodePtr &node) const;
|
||||
void EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
|
||||
// update node edge list
|
||||
|
|
|
@ -1794,7 +1794,7 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
|
|||
result.insert(result.end(), res.begin(), res.end());
|
||||
continue;
|
||||
}
|
||||
result.emplace_back(user.first);
|
||||
(void)result.emplace_back(user.first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -187,7 +187,7 @@ AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
|
||||
}
|
||||
|
||||
TypePtr mode_t = mode_v->cast<TypePtr>();
|
||||
auto mode_t = mode_v->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
bool v = IsSubtype(args_spec_list[0], mode_t);
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
|
||||
|
@ -252,10 +252,10 @@ AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const
|
|||
auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
|
||||
auto arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_x_value);
|
||||
|
||||
ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
|
||||
auto arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_y_value);
|
||||
|
||||
const std::vector<ValuePtr> x_shape = arg_x_value->value();
|
||||
|
@ -619,9 +619,9 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt
|
|||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
auto value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
auto type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
}
|
||||
|
|
|
@ -60,8 +60,7 @@ AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
#define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \
|
||||
static auto helper_##name = \
|
||||
abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false);
|
||||
auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, infer_impl, infer_value_impl, false);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
|
|
|
@ -601,7 +601,8 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
|
|||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool need_infer_value = eval_impl_.in_white_list_ || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
|
||||
bool need_infer_value =
|
||||
!(eval_impl_.in_white_list_) || (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
|
||||
AbstractBasePtr abs_base = nullptr;
|
||||
ValuePtr value = nullptr;
|
||||
prim_->BeginRecordAddAttr();
|
||||
|
|
|
@ -52,7 +52,7 @@ void SyncData(const py::object &arg) {
|
|||
}
|
||||
if (py::isinstance<tensor::Tensor>(arg)) {
|
||||
auto tensor = py::cast<tensor::TensorPtr>(arg);
|
||||
(void)tensor->data_sync();
|
||||
tensor->data_sync();
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
@ -234,7 +234,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
|
|||
obj = py_args[2];
|
||||
}
|
||||
CheckHookConsistency(obj, py_args[2]);
|
||||
hook_grad_.erase(cell_id);
|
||||
(void)hook_grad_.erase(cell_id);
|
||||
} else {
|
||||
hook_grad_[cell_id] = py_args[2];
|
||||
obj = py_args[2];
|
||||
|
|
|
@ -57,7 +57,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
auto key_string = GetValue<std::string>(keyPtr);
|
||||
key_value.emplace_back(key_string, value_list[index]);
|
||||
(void)key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
|
|
@ -29,17 +29,17 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
// check
|
||||
CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("bias rank", b_shape.size(), kEqual, 1, prim_name);
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
auto x_channel = x_shape[1];
|
||||
if (format != NCHW) {
|
||||
x_channel = x_shape[x_shape.size() - 1];
|
||||
}
|
||||
CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name);
|
||||
(void)CheckAndConvertUtils::Check("b_shape[0]", b_shape[0], kEqual, "x_shape[1]", x_channel, prim_name);
|
||||
|
||||
std::vector<int64_t> out_shape = x_shape;
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
@ -53,14 +53,14 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("input_x", input_args[0]->BuildType());
|
||||
types.emplace("bias", input_args[1]->BuildType());
|
||||
(void)types.emplace("input_x", input_args[0]->BuildType());
|
||||
(void)types.emplace("bias", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
|
||||
}
|
||||
} // namespace
|
||||
void BiasAdd::set_format(const Format &format) {
|
||||
int64_t f = format;
|
||||
this->AddAttr(kFormat, MakeValue(f));
|
||||
(void)this->AddAttr(kFormat, MakeValue(f));
|
||||
}
|
||||
Format BiasAdd::get_format() const {
|
||||
auto value_ptr = GetAttr(kFormat);
|
||||
|
|
|
@ -59,7 +59,7 @@ std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector
|
|||
pad_list[3] = pad_needed_h - pad_left;
|
||||
} else if (pad_mode == PAD) {
|
||||
auto pad = GetValue<std::vector<int64_t>>(primitive->GetAttr(kPad));
|
||||
std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
|
||||
(void)std::copy(pad.begin(), pad.end(), std::back_inserter(pad_list));
|
||||
auto pad_top = pad[0];
|
||||
auto pad_bottom = pad[1];
|
||||
auto pad_right = pad[2];
|
||||
|
@ -70,15 +70,16 @@ std::vector<int64_t> SetPadList(const PrimitivePtr &primitive, const std::vector
|
|||
h_out = floor(h_out);
|
||||
w_out = floor(w_out);
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
|
||||
primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, primitive->name());
|
||||
(void)primitive->AddAttr(kPadList,
|
||||
MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, primitive->name())));
|
||||
std::vector<int64_t> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
||||
return out_shape;
|
||||
}
|
||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInRange<size_t>("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
|
||||
|
@ -86,16 +87,16 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
x_shape = {x_shape[0], x_shape[3], x_shape[1], x_shape[2]};
|
||||
w_shape = {w_shape[0], w_shape[3], w_shape[1], w_shape[2]};
|
||||
}
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)), kEqual,
|
||||
"w_shape[1]", w_shape[1], prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
(void)CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / GetValue<int64_t>(primitive->GetAttr(kGroup)),
|
||||
kEqual, "w_shape[1]", w_shape[1], prim_name);
|
||||
auto out_channel = GetValue<int64_t>(primitive->GetAttr(kOutChannel));
|
||||
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
|
||||
(void)CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], prim_name);
|
||||
std::vector<int64_t> temp_w;
|
||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)), kEqual,
|
||||
"w_shape[2:4]", temp_w, prim_name);
|
||||
(void)std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
(void)CheckAndConvertUtils::Check("kernel_size", GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize)),
|
||||
kEqual, "w_shape[2:4]", temp_w, prim_name);
|
||||
auto out_shape = SetPadList(primitive, w_shape, x_shape, out_channel);
|
||||
if (format == NHWC) {
|
||||
out_shape = {out_shape[0], out_shape[3], out_shape[1], out_shape[2]};
|
||||
|
@ -104,14 +105,14 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
}
|
||||
|
||||
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckInRange<size_t>("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt32, kInt64, kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
@ -160,7 +161,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) {
|
|||
}
|
||||
|
||||
void Conv2D::set_pad(const std::vector<int64_t> &pad) {
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
(void)CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name());
|
||||
AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name())));
|
||||
}
|
||||
|
||||
|
|
|
@ -57,8 +57,8 @@ TypePtr MatMulInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
}
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32, kFloat64};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -90,7 +90,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
|
||||
auto pad_mode_value = (primitive->GetAttr(kPadMode));
|
||||
PadMode pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
|
||||
auto pad_mode = PadMode(GetValue<int64_t>(pad_mode_value));
|
||||
auto batch = in_shape[0];
|
||||
auto channel = in_shape[1];
|
||||
auto in_h = in_shape[2];
|
||||
|
|
|
@ -36,7 +36,7 @@ AbstractBasePtr MergeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
args.insert({"input[" + std::to_string(i) + "]", inputs_type[i]});
|
||||
}
|
||||
std::set<TypePtr> template_type = common_valid_types;
|
||||
template_type.emplace(kBool);
|
||||
(void)template_type.emplace(kBool);
|
||||
auto infered_type = CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, template_type, op_name);
|
||||
std::vector<int64_t> in_shape0 = inputs_shape[0]->cast<abstract::ShapePtr>()->shape();
|
||||
|
||||
|
|
|
@ -22,8 +22,8 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
void PrimitiveC::InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) {
|
||||
this->AddAttr("input_names", MakeValue(inputs_name));
|
||||
this->AddAttr("output_names", MakeValue(outputs_name));
|
||||
(void)this->AddAttr("input_names", MakeValue(inputs_name));
|
||||
(void)this->AddAttr("output_names", MakeValue(outputs_name));
|
||||
}
|
||||
|
||||
AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
|
||||
|
|
|
@ -39,10 +39,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
std::set<TypePtr> valid_params_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("shape type", input_args[0]->BuildType(), valid_params_types, op_name);
|
||||
AbstractBasePtrList abs_list;
|
||||
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
||||
[](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> {
|
||||
return std::make_shared<abstract::AbstractScalar>(item);
|
||||
});
|
||||
(void)std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
||||
[](int64_t item) -> std::shared_ptr<abstract::AbstractScalar> {
|
||||
return std::make_shared<abstract::AbstractScalar>(item);
|
||||
});
|
||||
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
|
||||
return abs;
|
||||
}
|
||||
|
|
|
@ -433,7 +433,7 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
|||
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << types.begin()->first
|
||||
<< " input must be a tensor but got " << type->ToString();
|
||||
}
|
||||
TypePtr check_type = _CheckTypeSame(types, prim_name, false);
|
||||
auto check_type = _CheckTypeSame(types, prim_name, false);
|
||||
return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "ir/tensor.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
namespace mindspore {
|
||||
template <typename T>
|
||||
void SetTensorData(void *data, T num, size_t data_length) {
|
||||
|
|
Loading…
Reference in New Issue