!5443 add tensor constructor

Merge pull request !5443 from lianliguang/run-graph-test
Author: mindspore-ci-bot, 2020-08-31 11:19:35 +08:00 (committed by Gitee)
Commit: f7900d6adf
14 changed files with 179 additions and 28 deletions


@@ -444,14 +444,9 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
   if (!anf_node->isa<CNode>()) {
     MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
   }
-  auto cnode = anf_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (input_idx + 1 >= cnode->inputs().size()) {
-    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
-  }
-  auto node = cnode->input(input_idx + 1);
-  MS_EXCEPTION_IF_NULL(node);
-  return VisitKernelWithReturnType(node, 0);
+  auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
+  MS_EXCEPTION_IF_NULL(input_node);
+  return VisitKernelWithReturnType(input_node, 0);
 }

 std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
@@ -975,7 +970,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
 AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
   auto get_input_index = index + 1;
-  if (index + 1 > node->inputs().size()) {
+  if (index + 1 >= node->inputs().size()) {
     MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
                       << node->inputs().size();
   }


@@ -1061,5 +1061,10 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
     }
   }
 }
+
+GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
+  RunInfer(func_graph, inputs);
+  return CompileGraph(func_graph);
+}
 }  // namespace session
 }  // namespace mindspore
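
For orientation, a hypothetical call site for the new overload (a sketch only, not part of the diff; `session`, `func_graph`, and the input shape are placeholders I am assuming here):

// Assumes an AscendSession instance "session" already exists and "func_graph" is a parsed
// FuncGraph whose Parameter nodes correspond one-to-one to the tensors in "inputs".
std::vector<tensor::TensorPtr> inputs = {
  TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, {32, 3, 224, 224})};
GraphId graph_id = session->CompileGraph(NOT_NULL(func_graph), inputs);  // runs RunInfer, then the usual compile path
VectorRef outputs;
session->RunGraph(graph_id, inputs, &outputs);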


@@ -52,6 +52,7 @@ class AscendSession : public SessionBasic {
   }
   GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
   GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
+  GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildGraph(GraphId) override;
   void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,


@@ -17,6 +17,7 @@
 #include <utility>
 #include <algorithm>
 #include <unordered_map>
+#include "c_ops/primitive_c.h"
 #include "pipeline/jit/parse/data_converter.h"
 #include "ir/manager.h"
 #include "ir/param_info.h"
@@ -1039,6 +1040,45 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {

 void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }

+void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
+  auto node_list = TopoSort(func_graph->get_return());
+  size_t tensor_index = 0;
+  for (const auto &node : node_list) {
+    MS_EXCEPTION_IF_NULL(node);
+    if (node->isa<CNode>()) {
+      AbstractBasePtrList input_abstracts;
+      for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
+        auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
+        MS_EXCEPTION_IF_NULL(input_node);
+        auto abstract = input_node->abstract();
+        MS_EXCEPTION_IF_NULL(abstract);
+        input_abstracts.emplace_back(abstract);
+      }
+      auto prim = AnfAlgo::GetCNodePrimitive(node);
+      if (prim->isa<PrimitiveC>()) {
+        auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>();
+        MS_EXCEPTION_IF_NULL(prim_c);
+        auto abstract = prim_c->Infer(input_abstracts);
+        node->set_abstract(abstract);
+      } else {
+        node->set_abstract(
+          std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int>{32, 64, 218, 218})->ToAbstract());
+      }
+    } else if (node->isa<Parameter>()) {
+      if (tensor_index >= inputs.size()) {
+        MS_EXCEPTION(IndexError) << "Index " << tensor_index << " is out of the inputs' size " << inputs.size();
+      }
+      node->set_abstract(inputs[tensor_index++]->ToAbstract());
+    } else {
+      auto value_node = node->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(value_node);
+      auto value = value_node->value();
+      MS_EXCEPTION_IF_NULL(value);
+      value_node->set_abstract(value->ToAbstract());
+    }
+  }
+}
+
 void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
   MS_LOG(DEBUG) << "Update summary Start";
   MS_EXCEPTION_IF_NULL(graph);


@@ -70,6 +70,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
   virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
+  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
+    MS_EXCEPTION(NotExistsError) << "Call an empty function";
+  }
   // build graph, used to handle multiple child graphs
   virtual void BuildGraph(GraphId) {}
@@ -129,6 +132,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);

  protected:
+  void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
   // Get graph by graph id ,if not exist return null ptr
   KernelGraphPtr GetGraph(GraphId graph_id) const;


@@ -37,10 +37,10 @@ constexpr auto kPadList = "pad_list";
 constexpr auto kConv2DName = "Conv2D";
 abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto conv_prim = std::dynamic_pointer_cast<Conv2d>(primitive);
+  auto conv_prim = primitive->cast<PrimConv2dPtr>();
   MS_EXCEPTION_IF_NULL(conv_prim);
   auto prim_name = conv_prim->name();
-  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
+  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
   auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
@@ -99,7 +99,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
 }
 TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
-  CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
+  CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
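
A reading of the changed check (an interpretation, not stated in the diff): kIncludeBoth presumably makes the {2, 3} range inclusive on both ends, so Conv2D inference accepts two inputs (x, weight) or three (x, weight, bias), whereas kIncludeLeft excluded the upper bound. Illustrative calls, mirroring the form used above:

CheckAndConvertUtils::CheckInRange("Conv2d Infer", 2, kIncludeBoth, {2, 3}, "Conv2D");  // in range, returns
CheckAndConvertUtils::CheckInRange("Conv2d Infer", 3, kIncludeBoth, {2, 3}, "Conv2D");  // in range, returns
CheckAndConvertUtils::CheckInRange("Conv2d Infer", 4, kIncludeBoth, {2, 3}, "Conv2D");  // out of range, raises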


@@ -29,6 +29,7 @@ class Conv2d : public PrimitiveC {
  public:
   Conv2d();
   ~Conv2d() = default;
+  MS_DECLARE_PARENT(Conv2d, PrimitiveC);
   void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
             const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
             const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);


@@ -25,6 +25,7 @@ namespace mindspore {
 class PrimitiveC : public Primitive {
  public:
   explicit PrimitiveC(const std::string &name) : Primitive(name) {}
+  MS_DECLARE_PARENT(PrimitiveC, Primitive);
   ~PrimitiveC() = default;
   AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list);
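
For context, a sketch of how the added RTTI declarations are used by the new RunInfer path above (simplified; `cnode` and `input_abstracts` stand in for the values computed there):

auto prim = AnfAlgo::GetCNodePrimitive(cnode);             // primitive attached to a CNode
if (prim != nullptr && prim->isa<PrimitiveC>()) {          // isa<> relies on MS_DECLARE_PARENT(PrimitiveC, Primitive)
  auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>(); // downcast to the C++ op definition
  cnode->set_abstract(prim_c->Infer(input_abstracts));     // per-op shape/type inference (e.g. Conv2d)
}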


@@ -640,7 +640,7 @@ CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) {
   return NewCNode(input_node_list);
 }

-ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) {
+ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
   auto parameter = add_parameter();
   parameter->set_default_param(MakeValue(meta_tensor));
   parameter->set_abstract(meta_tensor->ToAbstract());


@@ -173,7 +173,7 @@ class FuncGraph : public FuncGraphBase {
   CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
   virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
-  virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor);
+  virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
   // Functions for handling variable argument, keyword-only arguments and variable keyword argument
   AnfNodePtr GetDefaultValueByName(const std::string &name);
   void set_param_default_value(const std::string &name, const AnfNodePtr &node) {


@@ -64,11 +64,7 @@ std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
                                                            const std::vector<int> &arg_value,
                                                            const std::string &prim_name, bool allow_four,
                                                            bool ret_four) {
-  if (arg_value.size() == 2) {
-    return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value;
-  } else if (arg_value.size() == 4 && allow_four) {
-    return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]};
-  }
+  auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void {
   std::ostringstream buffer;
   buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
   if (allow_four) {
@@ -80,7 +76,24 @@ std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name,
   }
   buffer << "]";
   MS_EXCEPTION(ValueError) << buffer.str();
+  };
+  for (auto item : arg_value) {
+    if (item < 0) {
+      raise_message();
+    }
+  }
+  if (arg_value.size() == 1) {
+    return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[0]} : std::vector<int>{arg_value[0], arg_value[0]};
+  }
+  if (arg_value.size() == 2) {
+    return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value;
+  } else if (arg_value.size() == 4 && allow_four) {
+    return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]};
+  }
+  raise_message();
+  return arg_value;
 }

 std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
                                               const std::set<std::string> &check_list, const std::string &prim_name) {
   if (check_list.find(arg_value) != check_list.end()) {
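
Expected behavior of the reworked helper, for reference (a sketch; the argument values and the "stride"/"Conv2D" names are made up):

// ret_four expands the result to a four-element vector; allow_four permits a four-element input.
auto a = CheckAndConvertUtils::CheckPositiveVector("stride", {3}, "Conv2D", true, true);            // {1, 1, 3, 3}
auto b = CheckAndConvertUtils::CheckPositiveVector("stride", {2, 3}, "Conv2D", true, true);         // {1, 1, 2, 3}
auto c = CheckAndConvertUtils::CheckPositiveVector("stride", {1, 1, 2, 3}, "Conv2D", true, false);  // {2, 3}
auto d = CheckAndConvertUtils::CheckPositiveVector("stride", {-1, 3}, "Conv2D", true, true);        // negative element: raises ValueError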
@@ -131,6 +144,10 @@ void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_value,
   if (iter == kCompareRangeMap.end()) {
     MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
   }
+  if (range.first >= range.second) {
+    MS_EXCEPTION(ArgumentError) << "the check range's left value must be smaller than its right value, but got ["
+                                << range.first << "," << range.second << "]";
+  }
   if (iter->second(arg_value, range)) {
     return;
   }


@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
-#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
+#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
+#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
 #include <vector>
 #include <string>
 #include <map>
@@ -67,4 +67,4 @@ class CheckAndConvertUtils {
   static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2);
 };
 }  // namespace mindspore
-#endif  // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
+#endif  // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_


@@ -0,0 +1,59 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "utils/tensor_construct_utils.h"
+#include <vector>
+#include <memory>
+namespace mindspore {
+namespace {
+template <typename T>
+void SetTensorData(void *data, float num, size_t data_length) {
+  MS_EXCEPTION_IF_NULL(data);
+  auto tensor_data = reinterpret_cast<T *>(data);
+  MS_EXCEPTION_IF_NULL(tensor_data);
+  for (size_t index = 0; index < data_length; ++index) {
+    *tensor_data = num;
+    ++tensor_data;
+  }
+}
+}  // namespace
+
+tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int> &shape) {
+  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
+  size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
+  auto tensor_data = tensor->data_c();
+  char *data = reinterpret_cast<char *>(tensor_data);
+  MS_EXCEPTION_IF_NULL(data);
+  (void)memset_s(data, mem_size, 0, mem_size);
+  return tensor;
+}
+
+tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int> &shape) {
+  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
+  auto mem_size = IntToSize(tensor->ElementsNum());
+  if (tensor->data_type() == kNumberTypeFloat32) {
+    SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
+  } else if (tensor->data_type() == kNumberTypeInt) {
+    SetTensorData<int>(tensor->data_c(), 1, mem_size);
+  }
+  return tensor;
+}
+
+tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector<int> &shape, void *data) {
+  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape, data, type);
+  return tensor;
+}
+}  // namespace mindspore


@@ -0,0 +1,28 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
+#define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
+#include <vector>
+#include "ir/tensor.h"
+namespace mindspore {
+class TensorConstructUtils {
+ public:
+  static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int> &shape);
+  static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector<int> &shape);
+  static tensor::TensorPtr CreateTensor(TypeId type, const std::vector<int> &shape, void *data);
+};
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
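
A minimal usage sketch of the new tensor construction helpers (not part of the change; the shapes, values, and the `Example` function are arbitrary):

#include <vector>
#include "utils/tensor_construct_utils.h"

void Example() {
  // 2x3 float32 tensor filled with zeros.
  auto zeros = mindspore::TensorConstructUtils::CreateZerosTensor(mindspore::kNumberTypeFloat32, {2, 3});
  // 2x3 float32 tensor filled with ones.
  auto ones = mindspore::TensorConstructUtils::CreateOnesTensor(mindspore::kNumberTypeFloat32, {2, 3});
  // Tensor built over an existing host buffer.
  std::vector<float> buf(6, 0.5f);
  auto wrapped = mindspore::TensorConstructUtils::CreateTensor(mindspore::kNumberTypeFloat32, {2, 3}, buf.data());
}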