forked from mindspore-Ecosystem/mindspore
!5443 add tensor constructor
Merge pull request !5443 from lianliguang/run-graph-test
This commit is contained in:
commit
f7900d6adf
|
@ -444,14 +444,9 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod
|
||||||
if (!anf_node->isa<CNode>()) {
|
if (!anf_node->isa<CNode>()) {
|
||||||
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
|
||||||
}
|
}
|
||||||
auto cnode = anf_node->cast<CNodePtr>();
|
auto input_node = AnfAlgo::GetInputNode(anf_node->cast<CNodePtr>(), input_idx);
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
if (input_idx + 1 >= cnode->inputs().size()) {
|
return VisitKernelWithReturnType(input_node, 0);
|
||||||
MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
|
|
||||||
}
|
|
||||||
auto node = cnode->input(input_idx + 1);
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
return VisitKernelWithReturnType(node, 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
|
||||||
|
@ -975,7 +970,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
|
||||||
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
|
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
auto get_input_index = index + 1;
|
auto get_input_index = index + 1;
|
||||||
if (index + 1 > node->inputs().size()) {
|
if (index + 1 >= node->inputs().size()) {
|
||||||
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
|
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
|
||||||
<< node->inputs().size();
|
<< node->inputs().size();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1061,5 +1061,10 @@ void AscendSession::UpdateRefOutputMap(NotNull<KernelGraphPtr> graph,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph, const vector<tensor::TensorPtr> &inputs) {
|
||||||
|
RunInfer(func_graph, inputs);
|
||||||
|
return CompileGraph(func_graph);
|
||||||
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -52,6 +52,7 @@ class AscendSession : public SessionBasic {
|
||||||
}
|
}
|
||||||
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||||
|
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) override;
|
||||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
|
||||||
void BuildGraph(GraphId) override;
|
void BuildGraph(GraphId) override;
|
||||||
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include "c_ops/primitive_c.h"
|
||||||
#include "pipeline/jit/parse/data_converter.h"
|
#include "pipeline/jit/parse/data_converter.h"
|
||||||
#include "ir/manager.h"
|
#include "ir/manager.h"
|
||||||
#include "ir/param_info.h"
|
#include "ir/param_info.h"
|
||||||
|
@ -1039,6 +1040,45 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
|
||||||
|
|
||||||
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
|
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
|
||||||
|
|
||||||
|
void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
|
||||||
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
|
size_t tensor_index = 0;
|
||||||
|
for (const auto &node : node_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (node->isa<CNode>()) {
|
||||||
|
AbstractBasePtrList input_abstracts;
|
||||||
|
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
|
||||||
|
auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_node);
|
||||||
|
auto abstract = input_node->abstract();
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract);
|
||||||
|
input_abstracts.emplace_back(abstract);
|
||||||
|
}
|
||||||
|
auto prim = AnfAlgo::GetCNodePrimitive(node);
|
||||||
|
if (prim->isa<PrimitiveC>()) {
|
||||||
|
auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>();
|
||||||
|
MS_EXCEPTION_IF_NULL(prim_c);
|
||||||
|
auto abstract = prim_c->Infer(input_abstracts);
|
||||||
|
node->set_abstract(abstract);
|
||||||
|
} else {
|
||||||
|
node->set_abstract(
|
||||||
|
std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int>{32, 64, 218, 218})->ToAbstract());
|
||||||
|
}
|
||||||
|
} else if (node->isa<Parameter>()) {
|
||||||
|
if (tensor_index > inputs.size()) {
|
||||||
|
MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size";
|
||||||
|
}
|
||||||
|
node->set_abstract(inputs[tensor_index++]->ToAbstract());
|
||||||
|
} else {
|
||||||
|
auto value_node = node->cast<ValueNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
auto value = value_node->value();
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
value_node->set_abstract(value->ToAbstract());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
|
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
|
||||||
MS_LOG(DEBUG) << "Update summary Start";
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
|
|
@ -70,6 +70,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
|
|
||||||
virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
|
virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0;
|
||||||
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
|
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) { return kInvalidGraphId; }
|
||||||
|
virtual GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
|
||||||
|
MS_EXCEPTION(NotExistsError) << "Call an empty function";
|
||||||
|
}
|
||||||
// build graph, used to handle multiple child graphs
|
// build graph, used to handle multiple child graphs
|
||||||
virtual void BuildGraph(GraphId) {}
|
virtual void BuildGraph(GraphId) {}
|
||||||
|
|
||||||
|
@ -129,6 +132,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
||||||
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
|
void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
|
||||||
// Get graph by graph id ,if not exist return null ptr
|
// Get graph by graph id ,if not exist return null ptr
|
||||||
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
KernelGraphPtr GetGraph(GraphId graph_id) const;
|
||||||
|
|
||||||
|
|
|
@ -37,10 +37,10 @@ constexpr auto kPadList = "pad_list";
|
||||||
constexpr auto kConv2DName = "Conv2D";
|
constexpr auto kConv2DName = "Conv2D";
|
||||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto conv_prim = std::dynamic_pointer_cast<Conv2d>(primitive);
|
auto conv_prim = primitive->cast<PrimConv2dPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(conv_prim);
|
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||||
auto prim_name = conv_prim->name();
|
auto prim_name = conv_prim->name();
|
||||||
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
|
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
|
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||||
for (const auto &item : input_args) {
|
for (const auto &item : input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
MS_EXCEPTION_IF_NULL(item);
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ class Conv2d : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
Conv2d();
|
Conv2d();
|
||||||
~Conv2d() = default;
|
~Conv2d() = default;
|
||||||
|
MS_DECLARE_PARENT(Conv2d, PrimitiveC);
|
||||||
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
||||||
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
||||||
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
||||||
|
|
|
@ -25,6 +25,7 @@ namespace mindspore {
|
||||||
class PrimitiveC : public Primitive {
|
class PrimitiveC : public Primitive {
|
||||||
public:
|
public:
|
||||||
explicit PrimitiveC(const std::string &name) : Primitive(name) {}
|
explicit PrimitiveC(const std::string &name) : Primitive(name) {}
|
||||||
|
MS_DECLARE_PARENT(PrimitiveC, Primitive);
|
||||||
~PrimitiveC() = default;
|
~PrimitiveC() = default;
|
||||||
AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list);
|
AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list);
|
||||||
|
|
||||||
|
|
|
@ -640,7 +640,7 @@ CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector<An
|
||||||
return NewCNode(input_node_list);
|
return NewCNode(input_node_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) {
|
ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) {
|
||||||
auto parameter = add_parameter();
|
auto parameter = add_parameter();
|
||||||
parameter->set_default_param(MakeValue(meta_tensor));
|
parameter->set_default_param(MakeValue(meta_tensor));
|
||||||
parameter->set_abstract(meta_tensor->ToAbstract());
|
parameter->set_abstract(meta_tensor->ToAbstract());
|
||||||
|
|
|
@ -173,7 +173,7 @@ class FuncGraph : public FuncGraphBase {
|
||||||
CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
|
CNodePtr NewCNodeWithScope(const std::vector<AnfNodePtr> &inputs, const ScopePtr &scope);
|
||||||
virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
|
virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
|
||||||
|
|
||||||
virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor);
|
virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
|
||||||
// Functions for handling variable argument, keyword-only arguments and variable keyword argument
|
// Functions for handling variable argument, keyword-only arguments and variable keyword argument
|
||||||
AnfNodePtr GetDefaultValueByName(const std::string &name);
|
AnfNodePtr GetDefaultValueByName(const std::string &name);
|
||||||
void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
|
void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
|
||||||
|
|
|
@ -64,23 +64,36 @@ std::vector<int> CheckAndConvertUtils::CheckPositiveVector(const std::string &ar
|
||||||
const std::vector<int> &arg_value,
|
const std::vector<int> &arg_value,
|
||||||
const std::string &prim_name, bool allow_four,
|
const std::string &prim_name, bool allow_four,
|
||||||
bool ret_four) {
|
bool ret_four) {
|
||||||
|
auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void {
|
||||||
|
std::ostringstream buffer;
|
||||||
|
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
|
||||||
|
if (allow_four) {
|
||||||
|
buffer << "or four ";
|
||||||
|
}
|
||||||
|
buffer << " positive int numbers , but got [";
|
||||||
|
for (auto item : arg_value) {
|
||||||
|
buffer << item << ",";
|
||||||
|
}
|
||||||
|
buffer << "]";
|
||||||
|
MS_EXCEPTION(ValueError) << buffer.str();
|
||||||
|
};
|
||||||
|
for (auto item : arg_value) {
|
||||||
|
if (item < 0) {
|
||||||
|
raise_message();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (arg_value.size() == 1) {
|
||||||
|
return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[0]} : std::vector<int>{arg_value[0], arg_value[0]};
|
||||||
|
}
|
||||||
if (arg_value.size() == 2) {
|
if (arg_value.size() == 2) {
|
||||||
return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value;
|
return ret_four ? std::vector<int>{1, 1, arg_value[0], arg_value[1]} : arg_value;
|
||||||
} else if (arg_value.size() == 4 && allow_four) {
|
} else if (arg_value.size() == 4 && allow_four) {
|
||||||
return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]};
|
return ret_four ? arg_value : std::vector<int>{arg_value[2], arg_value[3]};
|
||||||
}
|
}
|
||||||
std::ostringstream buffer;
|
raise_message();
|
||||||
buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two ";
|
return arg_value;
|
||||||
if (allow_four) {
|
|
||||||
buffer << "or four ";
|
|
||||||
}
|
|
||||||
buffer << " positive int numbers , but got [";
|
|
||||||
for (auto item : arg_value) {
|
|
||||||
buffer << item << ",";
|
|
||||||
}
|
|
||||||
buffer << "]";
|
|
||||||
MS_EXCEPTION(ValueError) << buffer.str();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
|
std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value,
|
||||||
const std::set<std::string> &check_list, const std::string &prim_name) {
|
const std::set<std::string> &check_list, const std::string &prim_name) {
|
||||||
if (check_list.find(arg_value) != check_list.end()) {
|
if (check_list.find(arg_value) != check_list.end()) {
|
||||||
|
@ -131,6 +144,10 @@ void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_val
|
||||||
if (iter == kCompareRangeMap.end()) {
|
if (iter == kCompareRangeMap.end()) {
|
||||||
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
|
MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map";
|
||||||
}
|
}
|
||||||
|
if (range.first >= range.second) {
|
||||||
|
MS_EXCEPTION(ArgumentError) << "the check range left must be larger than right number bug got [ " << range.first
|
||||||
|
<< "," << range.second;
|
||||||
|
}
|
||||||
if (iter->second(arg_value, range)) {
|
if (iter->second(arg_value, range)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
|
#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
|
||||||
#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
|
#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
@ -67,4 +67,4 @@ class CheckAndConvertUtils {
|
||||||
static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2);
|
static bool IsEqualVector(const std::vector<int> &vec_1, const std::vector<int> &vec_2);
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H
|
#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
namespace mindspore {
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
void SetTensorData(void *data, float num, size_t data_length) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
auto tensor_data = reinterpret_cast<T *>(data);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_data);
|
||||||
|
for (size_t index = 0; index < data_length; ++index) {
|
||||||
|
*tensor_data = num;
|
||||||
|
++tensor_data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int> &shape) {
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
|
||||||
|
|
||||||
|
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
|
||||||
|
auto tensor_data = tensor->data_c();
|
||||||
|
char *data = reinterpret_cast<char *>(tensor_data);
|
||||||
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
(void)memset_s(data, mem_size, 0, mem_size);
|
||||||
|
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int> &shape) {
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
|
||||||
|
auto mem_size = IntToSize(tensor->ElementsNum());
|
||||||
|
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||||
|
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
|
||||||
|
} else if (tensor->data_type() == kNumberTypeInt) {
|
||||||
|
SetTensorData<int>(tensor->data_c(), 1, mem_size);
|
||||||
|
}
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector<int> &shape, void *data) {
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape, data, type);
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,28 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||||
|
#define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
namespace mindspore {
|
||||||
|
class TensorConstructUtils {
|
||||||
|
public:
|
||||||
|
static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int> &shape);
|
||||||
|
static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector<int> &shape);
|
||||||
|
static tensor::TensorPtr CreateTensor(TypeId type, const std::vector<int> &shape, void *data);
|
||||||
|
};
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
Loading…
Reference in New Issue