forked from mindspore-Ecosystem/mindspore
support tensor attr shape and dtype in graph mode
This commit is contained in:
parent
fa96dfd161
commit
b075674cf2
|
@ -28,7 +28,8 @@ from ...ops.composite.base import _append
|
|||
__all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like']
|
||||
|
||||
trans = P.Transpose()
|
||||
|
||||
shape_ = P.Shape()
|
||||
dtype_ = P.DType()
|
||||
|
||||
def transpose(x):
|
||||
"""Implementation of `transpose`."""
|
||||
|
|
|
@ -93,7 +93,6 @@ inline const PrimitivePtr kPrimArrayToScalar = std::make_shared<Primitive>("arra
|
|||
inline const PrimitivePtr kPrimBroadcastShape = std::make_shared<Primitive>("broadcast_shape");
|
||||
inline const PrimitivePtr kPrimArrayMap = std::make_shared<Primitive>("array_map");
|
||||
inline const PrimitivePtr kPrimArrayReduce = std::make_shared<Primitive>("array_reduce");
|
||||
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
inline const PrimitivePtr kPrimCast = std::make_shared<Primitive>("Cast");
|
||||
inline const PrimitivePtr kPrimConcat = std::make_shared<Primitive>("Concat");
|
||||
inline const PrimitivePtr kPrimSqueeze = std::make_shared<Primitive>("Squeeze");
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "frontend/operator/cc_implementations.h"
|
||||
#include "abstract/param_validator.h"
|
||||
|
@ -80,23 +79,6 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
|||
return std::make_shared<AbstractTuple>(elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTensorPtr arg = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
MS_LOG(DEBUG) << "InferImplShape:" << arg->ToString();
|
||||
|
||||
AbstractBasePtrList values;
|
||||
auto shp = arg->shape();
|
||||
for (int entry : shp->shape()) {
|
||||
auto entry_v = MakeValue(entry);
|
||||
values.push_back(std::make_shared<AbstractScalar>(entry_v, entry_v->type()));
|
||||
}
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor and a tuple.
|
||||
|
|
|
@ -963,6 +963,7 @@ void ClearResAtexit() {
|
|||
abstract::ClearPrimEvaluatorMap();
|
||||
compile::ClearConvertCache();
|
||||
pipeline::GetMethodMap().clear();
|
||||
pipeline::GetAttrMap().clear();
|
||||
pipeline::ExecutorPy::ClearRes();
|
||||
pipeline::ReclaimOptimizer();
|
||||
pynative::PynativeExecutor::GetInstance()->ClearRes();
|
||||
|
|
|
@ -17,23 +17,20 @@
|
|||
*/
|
||||
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "debug/draw.h"
|
||||
#include "debug/trace.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "vm/segment_runner.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support opmap definition
|
||||
namespace pipeline {
|
||||
|
||||
MethodMap &GetMethodMap() {
|
||||
static MethodMap method_map = {
|
||||
BuiltInTypeMap &GetMethodMap() {
|
||||
static BuiltInTypeMap method_map = {
|
||||
{kObjectTypeString,
|
||||
{
|
||||
{"__bool__", std::string("str_bool")} // C.str_bool
|
||||
|
@ -191,6 +188,15 @@ MethodMap &GetMethodMap() {
|
|||
return method_map;
|
||||
}
|
||||
|
||||
BuiltInTypeMap &GetAttrMap() {
|
||||
static BuiltInTypeMap attr_map = {{kObjectTypeTensorType,
|
||||
{
|
||||
{"shape", std::string("shape_")}, // C.shape_
|
||||
{"dtype", std::string("dtype_")}, // C.dtype_
|
||||
}}};
|
||||
return attr_map;
|
||||
}
|
||||
|
||||
Resource::Resource(const py::object &obj)
|
||||
: engine_(std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager_)),
|
||||
input_(obj),
|
||||
|
@ -218,31 +224,42 @@ Resource::~Resource() {
|
|||
}
|
||||
}
|
||||
|
||||
bool Resource::IsTypeInMethodMap(const TypeId &type) {
|
||||
TypeId type_id = NormalizeTypeId(type);
|
||||
const MethodMap &method_map = GetMethodMap();
|
||||
auto iter = method_map.find(static_cast<int>(type_id));
|
||||
if (iter != method_map.end()) {
|
||||
return true;
|
||||
Any GetMethodOrAttr(const string &name, const TypeId &type_id, const BuiltInTypeMap &method_map) {
|
||||
auto type_method_map = method_map.find(static_cast<int>(type_id));
|
||||
if (type_method_map == method_map.end()) {
|
||||
return Any();
|
||||
}
|
||||
return false;
|
||||
auto method = type_method_map->second.find(name);
|
||||
if (method == type_method_map->second.end()) {
|
||||
return Any();
|
||||
}
|
||||
return method->second;
|
||||
}
|
||||
|
||||
bool Resource::IsTypeInBuiltInMap(const TypeId &type) {
|
||||
TypeId type_id = NormalizeTypeId(type);
|
||||
const BuiltInTypeMap &method_map = GetMethodMap();
|
||||
auto iter = method_map.find(static_cast<int>(type_id));
|
||||
if (iter == method_map.end()) {
|
||||
const BuiltInTypeMap &attr_map = GetAttrMap();
|
||||
iter = attr_map.find(static_cast<int>(type_id));
|
||||
if (iter == attr_map.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) {
|
||||
TypeId type_id = NormalizeTypeId(type);
|
||||
const MethodMap &method_map = GetMethodMap();
|
||||
auto iter = method_map.find(static_cast<int>(type_id));
|
||||
if (iter == method_map.end()) {
|
||||
MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map";
|
||||
return Any();
|
||||
}
|
||||
const BuiltInTypeMap &method_map = GetMethodMap();
|
||||
return GetMethodOrAttr(name, type_id, method_map);
|
||||
}
|
||||
|
||||
auto iter_map = iter->second.find(name);
|
||||
if (iter_map == iter->second.end()) {
|
||||
MS_LOG(WARNING) << "Object type: " << type_id << " have no method: " << name;
|
||||
return Any();
|
||||
}
|
||||
return iter_map->second;
|
||||
Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
|
||||
TypeId type_id = NormalizeTypeId(type);
|
||||
const BuiltInTypeMap &attr_map = GetAttrMap();
|
||||
return GetMethodOrAttr(name, type_id, attr_map);
|
||||
}
|
||||
|
||||
void Resource::Clean() {
|
||||
|
|
|
@ -44,9 +44,11 @@ const char kOutput[] = "output";
|
|||
|
||||
class InferenceResource;
|
||||
|
||||
using MethodMap = std::unordered_map<int, std::unordered_map<std::string, Any>>;
|
||||
using BuiltInTypeMap = std::unordered_map<int, std::unordered_map<std::string, Any>>;
|
||||
|
||||
MethodMap &GetMethodMap();
|
||||
BuiltInTypeMap &GetMethodMap();
|
||||
|
||||
BuiltInTypeMap &GetAttrMap();
|
||||
|
||||
class ResourceBase {
|
||||
public:
|
||||
|
@ -87,10 +89,12 @@ class Resource : public ResourceBase {
|
|||
|
||||
abstract::AnalysisEnginePtr engine() { return engine_; }
|
||||
|
||||
static bool IsTypeInMethodMap(const TypeId &type);
|
||||
static bool IsTypeInBuiltInMap(const TypeId &type);
|
||||
|
||||
static Any GetMethodPtr(const TypeId &type, const std::string &name);
|
||||
|
||||
static Any GetAttrPtr(const TypeId &type, const std::string &name);
|
||||
|
||||
const py::object &input() const { return input_; }
|
||||
|
||||
FuncGraphPtr func_graph() const { return func_graph_; }
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
|
@ -31,10 +30,8 @@
|
|||
#include "frontend/operator/prim_to_function.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "./common.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "ir/tensor.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
|
@ -64,7 +61,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
||||
{prim::kPrimShape, {InferImplShape, true}},
|
||||
{prim::kPrimPack, {InferImplPack, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
|
@ -634,7 +630,7 @@ EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveIm
|
|||
}
|
||||
|
||||
const int kResolveCaseUserDefineClass = 1;
|
||||
const int kResolveCaseBuildinTypeMethod = 2;
|
||||
const int kResolveCaseBuiltInType = 2;
|
||||
const int kResolveCaseFunction = 3;
|
||||
int GetResolveCase(const TypePtr &data_type) {
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
|
@ -643,8 +639,8 @@ int GetResolveCase(const TypePtr &data_type) {
|
|||
}
|
||||
|
||||
// try method map, if not in method map, the data_type should be External type.
|
||||
if (pipeline::Resource::IsTypeInMethodMap(data_type->type_id())) {
|
||||
return kResolveCaseBuildinTypeMethod;
|
||||
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
|
||||
return kResolveCaseBuiltInType;
|
||||
}
|
||||
|
||||
return kResolveCaseFunction;
|
||||
|
@ -674,8 +670,10 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
|
|||
manager->AddFuncGraph(func_graph);
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &old_conf) {
|
||||
enum REQUIRE_TYPE { ATTR, METHOD };
|
||||
|
||||
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
|
||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
|
||||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
|
||||
AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
|
||||
|
@ -701,6 +699,9 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
|
|||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
FuncGraphPtr func_graph = old_conf->node()->func_graph();
|
||||
CNodePtr new_cnode = func_graph->NewCNode(input);
|
||||
if (require_type == REQUIRE_TYPE::ATTR) {
|
||||
new_cnode = func_graph->NewCNode({new_cnode});
|
||||
}
|
||||
AnalysisEnginePtr eng = old_conf->engine();
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context());
|
||||
return eng->ForwardConfig(old_conf, fn_conf);
|
||||
|
@ -781,9 +782,9 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng
|
|||
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
|
||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
|
||||
const TypePtr &data_type, const ConfigPtr &data_conf,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(item_v);
|
||||
MS_EXCEPTION_IF_NULL(data_type);
|
||||
// The method maybe a Primitive or Composite
|
||||
|
@ -792,22 +793,29 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &eng
|
|||
}
|
||||
|
||||
std::string item_name = item_v->cast<StringImmPtr>()->value();
|
||||
Any method = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
|
||||
if (method.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Object type: " << data_type->ToString() << " has no method: " << item_name;
|
||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
|
||||
Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
|
||||
if (require.empty()) {
|
||||
require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
|
||||
if (require.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The object of type: " << data_type->ToString() << " has no method or attr: " << item_name;
|
||||
}
|
||||
require_type = REQUIRE_TYPE::ATTR;
|
||||
}
|
||||
|
||||
ValuePtr converted_v = nullptr;
|
||||
if (method.is<std::string>()) {
|
||||
if (require.is<std::string>()) {
|
||||
// composite registered in standard_method_map go to this branch
|
||||
converted_v = prim::GetPythonOps(method.cast<std::string>());
|
||||
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
|
||||
} else if (method.is<PrimitivePtr>()) {
|
||||
converted_v = method.cast<PrimitivePtr>();
|
||||
converted_v = prim::GetPythonOps(require.cast<std::string>());
|
||||
if (!converted_v->isa<Primitive>()) {
|
||||
AddToManager(engine, converted_v->cast<FuncGraphPtr>());
|
||||
}
|
||||
} else if (require.is<PrimitivePtr>()) {
|
||||
converted_v = require.cast<PrimitivePtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from method map, but got " << method.ToString();
|
||||
MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
|
||||
}
|
||||
return StaticGetterInferred(converted_v, data_conf, out_conf);
|
||||
return StaticGetterInferred(converted_v, data_conf, out_conf, require_type);
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
|
@ -831,8 +839,8 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
int case_v = GetResolveCase(data_type);
|
||||
if (case_v == kResolveCaseUserDefineClass) {
|
||||
return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
|
||||
} else if (case_v == kResolveCaseBuildinTypeMethod) {
|
||||
return GetEvaluatedValueForBuiltinTypeMethod(engine, item_value, data_type, data_conf, out_conf);
|
||||
} else if (case_v == kResolveCaseBuiltInType) {
|
||||
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
|
||||
} else {
|
||||
return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
|
||||
}
|
||||
|
|
|
@ -218,10 +218,6 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
@ -246,8 +242,6 @@ AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const Primitiv
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
|
|
|
@ -22,20 +22,21 @@ import copy
|
|||
import functools
|
||||
import itertools
|
||||
import numbers
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
from ..._checkparam import Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...common.tensor import Tensor
|
||||
from ...common.parameter import Parameter
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from .._utils import get_concat_offset
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..operations.math_ops import _infer_shape_reduce
|
||||
from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import typing
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
from ...common import dtype as mstype
|
||||
from ...common.parameter import Parameter
|
||||
from ...common.tensor import Tensor
|
||||
|
||||
|
||||
class _ScatterOp(PrimitiveWithInfer):
|
||||
|
@ -415,7 +416,7 @@ class Reshape(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Shape(Primitive):
|
||||
class Shape(PrimitiveWithInfer):
|
||||
"""
|
||||
Returns the shape of input tensor.
|
||||
|
||||
|
@ -436,6 +437,13 @@ class Shape(Primitive):
|
|||
def __init__(self):
|
||||
"""init Shape"""
|
||||
|
||||
def __infer__(self, x):
|
||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
||||
out = {'shape': (),
|
||||
'dtype': mstype.tuple_,
|
||||
'value': tuple(x['shape'])}
|
||||
return out
|
||||
|
||||
|
||||
class Squeeze(PrimitiveWithInfer):
|
||||
"""
|
||||
|
|
|
@ -267,11 +267,6 @@ TEST_F(TestOps, BroadCastShapeTest) {
|
|||
ASSERT_EQ(prim->name(), kPrimBroadcastShape->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, ShapeTest) {
|
||||
auto prim = std::make_shared<Primitive>("Shape");
|
||||
ASSERT_EQ(prim->name(), kPrimShape->name());
|
||||
}
|
||||
|
||||
TEST_F(TestOps, ArrayMapTest) {
|
||||
auto prim = std::make_shared<Primitive>("array_map");
|
||||
ASSERT_EQ(prim->name(), kPrimArrayMap->name());
|
||||
|
|
|
@ -36,23 +36,23 @@ class TestResource : public UT::Common {
|
|||
void TearDown() {}
|
||||
};
|
||||
|
||||
TEST_F(TestResource, test_standard_method_map) {
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt8));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt16));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt32));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeInt64));
|
||||
TEST_F(TestResource, test_built_in_type_map) {
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt8));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt16));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt32));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeInt64));
|
||||
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat16));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat32));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeFloat64));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat16));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat32));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeFloat64));
|
||||
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeBool));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kNumberTypeUInt));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTuple));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeList));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInMethodMap(kObjectTypeTensorType));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeBool));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kNumberTypeUInt));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTuple));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeList));
|
||||
ASSERT_TRUE(true == Resource::IsTypeInBuiltInMap(kObjectTypeTensorType));
|
||||
|
||||
MethodMap& map = GetMethodMap();
|
||||
for (auto& iter : map) {
|
||||
|
|
|
@ -467,24 +467,6 @@ TEST_F(TestPrim, test_env_add) {
|
|||
ASSERT_TRUE(*res == *exp);
|
||||
}
|
||||
|
||||
TEST_F(TestPrim, test_shape) {
|
||||
PrimitivePtr shap = std::make_shared<Primitive>("Shape");
|
||||
FuncGraphPtr func_graph = MakeFuncGraph(shap, 1);
|
||||
|
||||
auto a = UTPrimUtils::ArrayFloat64Of({2, 3});
|
||||
|
||||
AbstractBasePtrList args_spec_list = {a};
|
||||
|
||||
AbstractTuplePtr res = dyn_cast<AbstractTuple>(engine_->Run(func_graph, args_spec_list).inferred->abstract());
|
||||
auto ret = res->BuildValue()->cast<ValueTuplePtr>()->value();
|
||||
|
||||
std::vector<ValuePtr> element_list = {MakeValue(2), MakeValue(3)};
|
||||
ASSERT_TRUE(ret.size() == element_list.size());
|
||||
for (int i = 0; i < element_list.size(); i++) {
|
||||
ASSERT_TRUE(*ret[i] == *element_list[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TestPrim, test_relu) {
|
||||
PrimitivePtr relu = prim::kPrimRelu;
|
||||
relu->AddAttr("T", MakeValue(static_cast<int>(kNumberTypeFloat64)));
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test dtype and shape as attr"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dtype_and_shape_as_attr():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.shape
|
||||
dtype = x.dtype
|
||||
return shape, dtype
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
ret = net(x)
|
||||
assert ret == ((1, 2, 3), mstype.int32)
|
||||
|
||||
|
||||
def test_dtype_and_shape_as_attr_to_new_tensor():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, value):
|
||||
super(Net, self).__init__()
|
||||
self.fill = P.Fill()
|
||||
self.value = value
|
||||
|
||||
def construct(self, x):
|
||||
dtype = x.dtype
|
||||
shape = x.shape
|
||||
y = self.fill(dtype, shape, self.value)
|
||||
return y
|
||||
|
||||
|
||||
net = Net(2.2)
|
||||
x = Tensor(np.ones([1, 2, 3], np.float32))
|
||||
ret = net(x)
|
||||
assert (ret.asnumpy() == (np.zeros([1, 2, 3], np.float32) + 2.2)).all()
|
||||
|
||||
|
||||
def test_type_not_have_the_attr():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.shapes
|
||||
return shape
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net(x)
|
||||
assert "The object of type: Tensor[Int32] has no method or attr: shapes" in str(ex.value)
|
||||
|
||||
|
||||
def test_type_not_have_the_method():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
||||
def construct(self, x):
|
||||
shape = x.dtypes()
|
||||
return shape
|
||||
|
||||
|
||||
net = Net()
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as ex:
|
||||
net(x)
|
||||
assert "The object of type: Tensor[Int32] has no method or attr: dtypes" in str(ex.value)
|
|
@ -20,7 +20,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class FatherNet(nn.Cell):
|
||||
|
@ -92,7 +92,6 @@ class Net(nn.Cell):
|
|||
|
||||
def test_single_super():
|
||||
single_net = SingleSubNet(2, 3)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
single_net(x, y)
|
||||
|
@ -100,7 +99,6 @@ def test_single_super():
|
|||
|
||||
def test_mul_super():
|
||||
mul_net = MulSubNet(2, 3, 4)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
mul_net(x, y)
|
||||
|
@ -108,7 +106,6 @@ def test_mul_super():
|
|||
|
||||
def test_super_cell():
|
||||
net = Net(2)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
with pytest.raises(RuntimeError) as er:
|
||||
|
@ -142,7 +139,6 @@ def test_single_super_in():
|
|||
return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z
|
||||
|
||||
single_net_in = SingleSubNetIN(2, 3)
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
x = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
y = Tensor(np.ones([1, 2, 3], np.int32))
|
||||
single_net_in(x, y)
|
||||
|
|
Loading…
Reference in New Issue