forked from mindspore-Ecosystem/mindspore
!4255 unify primitive
Merge pull request !4255 from lianliguang/unify-primitive
This commit is contained in:
commit
468e797e97
|
@ -393,40 +393,5 @@ ValuePtr BoolEq(const ValuePtrList &list) {
|
|||
|
||||
MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
|
||||
}
|
||||
|
||||
std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) {
|
||||
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
|
||||
if (dlen < 0) {
|
||||
for (int i = 0; i < -dlen; ++i) {
|
||||
(void)shpx.insert(shpx.begin(), 1);
|
||||
}
|
||||
} else if (dlen > 0) {
|
||||
for (int i = 0; i < dlen; i++) {
|
||||
(void)shpy.insert(shpy.begin(), 1);
|
||||
}
|
||||
}
|
||||
if (shpx.size() != shpy.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
|
||||
}
|
||||
std::vector<int> shp;
|
||||
for (size_t i = 0; i < shpx.size(); i++) {
|
||||
auto a = shpx[i];
|
||||
auto b = shpy[i];
|
||||
if (a == 1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == 1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == -1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == -1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == b) {
|
||||
shp.push_back(a);
|
||||
} else {
|
||||
return std::vector<int>();
|
||||
}
|
||||
}
|
||||
return shp;
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -52,7 +52,6 @@ ValuePtr BoolNot(const ValuePtrList &list);
|
|||
ValuePtr BoolAnd(const ValuePtrList &list);
|
||||
ValuePtr BoolOr(const ValuePtrList &list);
|
||||
ValuePtr BoolEq(const ValuePtrList &list);
|
||||
std::vector<int> BroadcastShape_(std::vector<int> s1, std::vector<int> s2);
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -42,28 +42,13 @@ inline const PrimitivePtr kPrimRefToEmbed = std::make_shared<Primitive>("RefToEm
|
|||
inline const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_instance");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
|
||||
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
|
||||
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
|
||||
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
|
||||
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
||||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
|
||||
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
|
||||
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
|
||||
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
|
||||
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
|
||||
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
|
||||
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
|
||||
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
|
||||
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
|
||||
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@ -15,360 +13,266 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/operator/ops_front_infer_function.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/tensor_py.h"
|
||||
|
||||
using mindspore::tensor::TensorPy;
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
enum State {
|
||||
SAME,
|
||||
X_ONE,
|
||||
Y_ONE,
|
||||
};
|
||||
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
struct SlideInfo {
|
||||
int start;
|
||||
int step;
|
||||
int stop;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples or two lists.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractTuple>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractList>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
size_t keys_size = keys->size();
|
||||
if (values->size() != keys_size) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
|
||||
}
|
||||
|
||||
std::vector<AbstractAttribute> key_value;
|
||||
AbstractScalarPtr key;
|
||||
AbstractBasePtrList key_list = keys->elements();
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(keyPtr);
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||
int arg1 = 0;
|
||||
int arg2 = 0;
|
||||
if (!args_spec_list.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto arg_value = args_spec_list[0]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
key_value.emplace_back(key_string, value_list[index]);
|
||||
arg1 = GetValue<int>(arg_value);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and a keyword.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_input = GetValue<std::string>(key_value);
|
||||
std::string key_actual = kwarg->get_key();
|
||||
if (key_actual != key_input) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
|
||||
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
|
||||
}
|
||||
return kwarg->get_arg();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three scalars whose value is an int32 number.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||
size_t args_size = args_spec_list.size();
|
||||
for (size_t index = 0; index < args_size; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
|
||||
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
|
||||
if (args_spec_list.size() >= 2) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto arg_value = args_spec_list[1]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
if (args_spec_list[index]->isa<AbstractScalar>() &&
|
||||
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
|
||||
<< " parameter is an AbstractScalar, but is not an int32 number.";
|
||||
arg2 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto arg_value = args_spec_list[2]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
slide->step = GetValue<int>(arg_value);
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 2) {
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 1) {
|
||||
slide->stop = arg1;
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
|
||||
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
|
||||
const size_t n = reverse_x.size();
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
State curr;
|
||||
const int32_t x_i = reverse_x[i];
|
||||
const int32_t y_i = reverse_y[i];
|
||||
const int reduce_idx = SizeToInt(n - 1 - i);
|
||||
if (x_i == y_i) {
|
||||
curr = SAME;
|
||||
} else if (x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
curr = X_ONE;
|
||||
} else if (y_i == 1) {
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
curr = Y_ONE;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs";
|
||||
}
|
||||
if (curr == SAME && x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// Slice: start, end, step
|
||||
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
|
||||
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
|
||||
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
|
||||
}
|
||||
|
||||
// Eval the return type of make_record
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||
if (args_spec_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||
<< args_spec_list.size() << ".";
|
||||
AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
|
||||
std::vector<int> reverse_x;
|
||||
std::vector<int> reverse_y;
|
||||
|
||||
(void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
|
||||
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
|
||||
(void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
|
||||
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
|
||||
|
||||
if (reverse_x.size() > reverse_y.size()) {
|
||||
reverse_y.resize(reverse_x.size(), 1);
|
||||
} else {
|
||||
reverse_x.resize(reverse_y.size(), 1);
|
||||
}
|
||||
|
||||
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||
std::vector<int> grad_x_reduce_idx;
|
||||
std::vector<int> grad_y_reduce_idy;
|
||||
ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
|
||||
|
||||
AbstractBasePtrList abs_list_x;
|
||||
AbstractBasePtrList abs_list_y;
|
||||
(void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
|
||||
[](int v) { return abstract::FromValue(v); });
|
||||
(void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
|
||||
[](int v) { return abstract::FromValue(v); });
|
||||
auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
|
||||
auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
|
||||
AbstractBasePtrList elem_list;
|
||||
elem_list.push_back(x_reduce_idx);
|
||||
elem_list.push_back(y_reduce_idx);
|
||||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a pointer to an AbstractBase object
|
||||
if (args_spec_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
|
||||
<< ".";
|
||||
}
|
||||
AbstractBasePtr abs_base = args_spec_list[0];
|
||||
MS_EXCEPTION_IF_NULL(abs_base);
|
||||
TypePtr type = abs_base->BuildType();
|
||||
return std::make_shared<AbstractType>(type);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a pointer to an AbstractBase object and a pointer to a Type
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
|
||||
|
||||
auto mode_v = abs_type->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(mode_v);
|
||||
if (!mode_v->isa<Type>()) {
|
||||
MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
|
||||
}
|
||||
|
||||
TypePtr mode_t = mode_v->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
}
|
||||
|
||||
auto cls = dyn_cast<Class>(type_ptr);
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
ClassAttrVector attributes = cls->GetAttributes();
|
||||
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||
|
||||
std::vector<AbstractAttribute> abs_attributes;
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||
abs_attributes.push_back(elem);
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
bool v = IsSubtype(args_spec_list[0], mode_t);
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list and a scalar whose value is an int32 number.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
|
||||
if (x_shape.size() != y_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
|
||||
// and continue
|
||||
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
|
||||
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
std::size_t nelems = queue->elements().size();
|
||||
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
|
||||
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
|
||||
}
|
||||
|
||||
std::size_t uidx_v = 0;
|
||||
if (idx_v >= 0) {
|
||||
uidx_v = IntToSize(idx_v);
|
||||
} else {
|
||||
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
|
||||
}
|
||||
return queue->elements()[uidx_v];
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
if (idx_v < 0) {
|
||||
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||
<< ".";
|
||||
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
|
||||
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||
size_t x_rank = x_shape->size();
|
||||
std::set<int> axis_set;
|
||||
auto axis_data = axis_value_ptr->value();
|
||||
if (axis_data.empty()) {
|
||||
int size = 1;
|
||||
AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
size_t uidx_v = IntToSize(idx_v);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
std::size_t nelems = elements.size();
|
||||
if (uidx_v >= nelems) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||
<< ".";
|
||||
for (auto &elem : axis_data) {
|
||||
int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
|
||||
(void)axis_set.insert(e_value);
|
||||
}
|
||||
elements[uidx_v] = args_spec_list[2];
|
||||
return std::make_shared<T>(elements);
|
||||
|
||||
auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
|
||||
if (x_shp_data.size() < x_rank) {
|
||||
MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
|
||||
}
|
||||
AbstractBasePtrList values;
|
||||
for (size_t i = 0; i < x_rank; i++) {
|
||||
if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
|
||||
auto axis_v = MakeValue(1);
|
||||
values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
|
||||
} else {
|
||||
int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
|
||||
auto dim = MakeValue(dim_value);
|
||||
values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string.
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// this primitive get the index that need to reduce
|
||||
// input: x's shape and y's shape, inputs should be tuple
|
||||
// output: tuple of x and y 's reduce index, reduce index should be a tuple
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_x_value);
|
||||
|
||||
ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_y_value);
|
||||
|
||||
const std::vector<ValuePtr> x_shape = arg_x_value->value();
|
||||
const std::vector<ValuePtr> y_shape = arg_y_value->value();
|
||||
bool is_same_shape = CompareShape(x_shape, y_shape);
|
||||
// if it is the same shape , do not need reduce , return empty tuple
|
||||
if (is_same_shape) {
|
||||
AbstractBasePtrList empty_list;
|
||||
auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
|
||||
AbstractBasePtrList elem_list;
|
||||
elem_list.push_back(x_reduce_idx);
|
||||
elem_list.push_back(y_reduce_idx);
|
||||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
auto key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
if (it == dict_elems.end()) {
|
||||
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
|
||||
if (it != dict_elems.end()) {
|
||||
int index = it - dict_elems.begin();
|
||||
dict_elems[IntToSize(index)] = new_ele;
|
||||
} else {
|
||||
dict_elems.push_back(new_ele);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a list and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||
(void)AbstractJoin(list->elements());
|
||||
return list;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
|
||||
return BroadcastGradientArgsDiff(x_shape, y_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
|
@ -430,41 +334,6 @@ AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const Primitiv
|
|||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
|
||||
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||
size_t x_rank = x_shape->size();
|
||||
std::set<int> axis_set;
|
||||
auto axis_data = axis_value_ptr->value();
|
||||
if (axis_data.empty()) {
|
||||
int size = 1;
|
||||
AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
for (auto &elem : axis_data) {
|
||||
int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
|
||||
(void)axis_set.insert(e_value);
|
||||
}
|
||||
|
||||
auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
|
||||
if (x_shp_data.size() < x_rank) {
|
||||
MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
|
||||
}
|
||||
AbstractBasePtrList values;
|
||||
for (size_t i = 0; i < x_rank; i++) {
|
||||
if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
|
||||
auto axis_v = MakeValue(1);
|
||||
values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
|
||||
} else {
|
||||
int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
|
||||
auto dim = MakeValue(dim_value);
|
||||
values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: x_shape, axis
|
||||
|
@ -563,7 +432,7 @@ AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitiveP
|
|||
|
||||
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
|
||||
py::array data = py::array(data_tuple);
|
||||
auto tensor = TensorPy::MakeTensor(data);
|
||||
auto tensor = tensor::TensorPy::MakeTensor(data);
|
||||
auto ret = tensor->ToAbstract();
|
||||
ret->set_value(tensor);
|
||||
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
|
||||
|
@ -596,76 +465,6 @@ AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return std::make_shared<AbstractScalar>(result_v, result_v->type());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples or two lists.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
struct SlideInfo {
|
||||
int start;
|
||||
int step;
|
||||
int stop;
|
||||
};
|
||||
|
||||
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||
int arg1 = 0;
|
||||
int arg2 = 0;
|
||||
if (!args_spec_list.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto arg_value = args_spec_list[0]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg1 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() >= 2) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto arg_value = args_spec_list[1]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg2 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto arg_value = args_spec_list[2]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
slide->step = GetValue<int>(arg_value);
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 2) {
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 1) {
|
||||
slide->stop = arg1;
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
if (args_spec_list.empty()) {
|
||||
|
@ -709,5 +508,145 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
|
|||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Clone();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: An object of AbstractFunction.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();
|
||||
|
||||
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
|
||||
if (x == nullptr) {
|
||||
return std::make_shared<AbstractJTagged>(args_spec_list[0]);
|
||||
}
|
||||
|
||||
AbstractFuncAtomPtrList jv;
|
||||
auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
|
||||
auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
|
||||
jv.push_back(j_closure);
|
||||
};
|
||||
x->Visit(build_jv);
|
||||
|
||||
return AbstractFunction::MakeAbstractFunction(jv);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
// Eval the return type of make_record
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||
if (args_spec_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||
<< args_spec_list.size() << ".";
|
||||
}
|
||||
|
||||
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
}
|
||||
|
||||
auto cls = dyn_cast<Class>(type_ptr);
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
ClassAttrVector attributes = cls->GetAttributes();
|
||||
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||
|
||||
std::vector<AbstractAttribute> abs_attributes;
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||
abs_attributes.push_back(elem);
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListMap, prim::kPrimListMap, InferImplListMap);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListReduce, prim::kPrimListReduce, InferImplListReduce);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleReversed, prim::kPrimTupleReversed, InferImplTupleReversed);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ReducedShape, prim::kPrimReducedShape, InferImplReduceShape);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleDiv, prim::kPrimTupleDiv, InferImplTupleDiv);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleToArray, prim::kPrimTupleToArray, InferImplTuple2Array);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ShapeMul, prim::kPrimShapeMul, InferImplShapeMul);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TupleEqual, prim::kPrimTupleEqual, InferImplTupleEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(ListEqual, prim::kPrimListEqual, InferImplListEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRange, prim::kPrimMakeRange, InferImplMakeRange);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StopGradient, prim::kPrimStopGradient, InferImplStopGradient);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringEqual, prim::kPrimStringEqual, InferImplStringEqual);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(StringConcat, prim::kPrimStringConcat, InferImplStringConcat);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(DictLen, prim::kPrimDictLen, InferImplDictLen);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(FakeBprop, prim::kPrimFakeBprop, InferImplFakeBprop);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ);
|
||||
REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs,
|
||||
InferImplBroadcastGradientArgs);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
class RegisterFrontendPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterFrontendPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, false};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterFrontendPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
|
||||
static auto helper_##name = RegisterFrontendPrimitiveEvalHelper(primitive, impl)
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
|
|
@ -36,115 +36,12 @@
|
|||
#include "utils/convert_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||
static PrimitiveEvalImplMap prim_eval_implement_map = {
|
||||
// Statements
|
||||
{prim::kPrimReturn, {InferImplReturn, true}},
|
||||
{prim::kPrimTypeOf, {InferImplTypeof, false}},
|
||||
{prim::kPrimHasType, {InferImplHasType, false}},
|
||||
{prim::kPrimDot, {InferImplDot, true}},
|
||||
{prim::kPrimSwitch, {InferImplSwitch, true}},
|
||||
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
|
||||
{prim::kPrimIs_, {InferImplIs_, true}},
|
||||
{prim::kPrimIsNot, {InferImplIsNot, true}},
|
||||
{prim::kPrimInDict, {InferImplInDict, true}},
|
||||
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
|
||||
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
|
||||
// Maths
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
// Array
|
||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
||||
{prim::kPrimPack, {InferImplPack, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
||||
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
|
||||
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
|
||||
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
|
||||
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
|
||||
{prim::kPrimMakeRecord, {InferImplMakeRecord, false}},
|
||||
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
|
||||
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
|
||||
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
|
||||
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
|
||||
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
|
||||
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
|
||||
{prim::kPrimListAppend, {InferImplListAppend, true}},
|
||||
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
|
||||
{prim::kPrimListLen, {InferImplListLen, true}},
|
||||
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
|
||||
{prim::kPrimListMap, {InferImplListMap, false}},
|
||||
{prim::kPrimListReduce, {InferImplListReduce, false}},
|
||||
{prim::kPrimTupleReversed, {InferImplTupleReversed, false}},
|
||||
{prim::kPrimReducedShape, {InferImplReduceShape, false}},
|
||||
{prim::kPrimTupleDiv, {InferImplTupleDiv, false}},
|
||||
{prim::kPrimTupleToArray, {InferImplTuple2Array, false}},
|
||||
{prim::kPrimShapeMul, {InferImplShapeMul, false}},
|
||||
{prim::kPrimTupleEqual, {InferImplTupleEqual, false}},
|
||||
{prim::kPrimListEqual, {InferImplListEqual, false}},
|
||||
{prim::kPrimMakeRange, {InferImplMakeRange, false}},
|
||||
{prim::kPrimStopGradient, {InferImplStopGradient, false}},
|
||||
{prim::kPrimStringEqual, {InferImplStringEqual, false}},
|
||||
{prim::kPrimStringConcat, {InferImplStringConcat, false}},
|
||||
{prim::kPrimDictLen, {InferImplDictLen, false}},
|
||||
// NN
|
||||
{prim::kPrimPooling, {InferImplPooling, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
|
||||
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
|
||||
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
|
||||
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
|
||||
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, true}},
|
||||
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
|
||||
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
|
||||
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
|
||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
||||
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
|
||||
// Others
|
||||
{prim::kPrimIdentity, {InferImplIdentity, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, {nullptr, true}},
|
||||
{prim::kPrimJ, {InferImplJ, false}},
|
||||
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
|
||||
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
|
||||
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
|
||||
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
|
||||
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
|
||||
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
|
||||
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
|
||||
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
|
||||
{prim::kPrimDepend, {InferImplDepend, true}},
|
||||
{prim::kPrimBroadcastGradientArgs, {InferImplBroadcastGradientArgs, false}},
|
||||
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
|
||||
// SparseTensor
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
|
||||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
||||
using mindspore::parse::PyObjectWrapper;
|
||||
|
||||
std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
|
||||
|
|
|
@ -26,19 +26,10 @@
|
|||
#include <vector>
|
||||
|
||||
#include "pipeline/jit/static_analysis/evaluator.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
struct StandartPrimitiveImplReg {
|
||||
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
|
||||
bool in_white_list_; // true if this Primitive in white list, else false.
|
||||
};
|
||||
|
||||
using PrimitiveEvalImplMap =
|
||||
std::unordered_map<PrimitivePtr, StandartPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
|
||||
|
||||
class StandardPrimEvaluator : public TrivialPrimEvaluator {
|
||||
public:
|
||||
StandardPrimEvaluator(const PrimitivePtr primitive, StandardPrimitiveEvalImpl eval_impl)
|
||||
|
@ -179,191 +170,6 @@ bool IsSubtype(const AbstractBasePtr x, const TypePtr model);
|
|||
void ClearPrimEvaluatorMap();
|
||||
|
||||
py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base);
|
||||
|
||||
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
|
||||
#define MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "base/core_ops.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsConstant(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPoolingGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLayerNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplArrayToScalar(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplPack(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGenShapeIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGenInverseIndex(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplEnvAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefKey(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefValue(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_INFER_FUNCTIONS_H_
|
|
@ -14,13 +14,48 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "frontend/operator/cc_implementations.h"
|
||||
#include "abstract/param_validator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
namespace {
|
||||
std::vector<int> BroadcastShape(std::vector<int> shpx, std::vector<int> shpy) {
|
||||
int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size());
|
||||
if (dlen < 0) {
|
||||
for (int i = 0; i < -dlen; ++i) {
|
||||
(void)shpx.insert(shpx.begin(), 1);
|
||||
}
|
||||
} else if (dlen > 0) {
|
||||
for (int i = 0; i < dlen; i++) {
|
||||
(void)shpy.insert(shpy.begin(), 1);
|
||||
}
|
||||
}
|
||||
if (shpx.size() != shpy.size()) {
|
||||
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size().";
|
||||
}
|
||||
std::vector<int> shp;
|
||||
for (size_t i = 0; i < shpx.size(); i++) {
|
||||
auto a = shpx[i];
|
||||
auto b = shpy[i];
|
||||
if (a == 1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == 1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == -1) {
|
||||
shp.push_back(b);
|
||||
} else if (b == -1) {
|
||||
shp.push_back(a);
|
||||
} else if (a == b) {
|
||||
shp.push_back(a);
|
||||
} else {
|
||||
return std::vector<int>();
|
||||
}
|
||||
}
|
||||
return shp;
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a scalar.
|
||||
|
@ -65,7 +100,7 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti
|
|||
(void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y),
|
||||
[](const ValuePtr &e) -> int { return GetValue<int>(e); });
|
||||
|
||||
std::vector<int> res = prim::BroadcastShape_(shp_x, shp_y);
|
||||
std::vector<int> res = BroadcastShape(shp_x, shp_y);
|
||||
if (res.empty()) {
|
||||
MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << ","
|
||||
<< args_spec_list[1]->ToString();
|
|
@ -15,8 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "abstract/param_validator.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
|
@ -14,8 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/ms_utils.h"
|
|
@ -14,10 +14,12 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "c_ops/conv2d.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
@ -278,13 +280,6 @@ AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
|
@ -433,5 +428,91 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
|
|||
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
|
||||
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
|
||||
}
|
||||
|
||||
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto conv_prim = primitive->cast<PrimConv2dPtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||
auto prim_name = conv_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
|
||||
w_shape[1], conv_prim->name());
|
||||
auto out_channel = conv_prim->GetOutputChannel();
|
||||
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
|
||||
std::vector<int> temp_w;
|
||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
|
||||
conv_prim->name());
|
||||
|
||||
auto kernel_size_h = w_shape[2];
|
||||
auto kernel_size_w = w_shape[3];
|
||||
auto stride = conv_prim->GetStride();
|
||||
auto dilation = conv_prim->GetDilation();
|
||||
auto stride_h = stride[2];
|
||||
auto stride_w = stride[3];
|
||||
auto dilation_h = dilation[2];
|
||||
auto dilation_w = dilation[3];
|
||||
int h_out = -1;
|
||||
int w_out = -1;
|
||||
std::vector<int> pad_list(4, 0);
|
||||
auto pad_mode = conv_prim->GetPadMode();
|
||||
if (pad_mode == "valid") {
|
||||
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
||||
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
h_out = ceil(x_shape[2] / stride_h);
|
||||
w_out = ceil(x_shape[3] / stride_w);
|
||||
|
||||
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
||||
pad_list.emplace_back(floor(pad_needed_h / 2));
|
||||
pad_list.emplace_back(pad_needed_h / 2);
|
||||
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
||||
auto pad_left = floor(pad_needed_w / 2);
|
||||
pad_list.emplace_back(pad_left);
|
||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
||||
} else if (pad_mode == "pad") {
|
||||
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
|
||||
auto pad_top = conv_prim->GetPad()[0];
|
||||
auto pad_bottom = conv_prim->GetPad()[1];
|
||||
auto pad_right = conv_prim->GetPad()[2];
|
||||
auto pad_left = conv_prim->GetPad()[3];
|
||||
|
||||
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
|
||||
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
|
||||
h_out = floor(h_out);
|
||||
w_out = floor(w_out);
|
||||
}
|
||||
conv_prim->SetPadList(pad_list);
|
||||
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
|
||||
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->GetTypeTrack());
|
||||
types.emplace("w", input_args[1]->GetTypeTrack());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (x_type == kNumberTypeInt8) {
|
||||
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
|
||||
}
|
||||
return std::make_shared<TensorType>(TypeIdToType(x_type));
|
||||
}
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
|
||||
Conv2dInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -19,9 +19,9 @@
|
|||
|
||||
#include "ir/dtype.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
@ -35,27 +35,6 @@ AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
return args_spec_list[0];
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: An object of AbstractFunction.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
MS_LOG(DEBUG) << "evaluate J: " << args_spec_list[0]->ToString();
|
||||
|
||||
AbstractFunctionPtr x = dyn_cast<AbstractFunction>(args_spec_list[0]);
|
||||
if (x == nullptr) {
|
||||
return std::make_shared<AbstractJTagged>(args_spec_list[0]);
|
||||
}
|
||||
|
||||
AbstractFuncAtomPtrList jv;
|
||||
auto build_jv = [&jv](const AbstractFuncAtomPtr &func) {
|
||||
auto j_closure = std::make_shared<JTransformedAbstractClosure>(func);
|
||||
jv.push_back(j_closure);
|
||||
};
|
||||
x->Visit(build_jv);
|
||||
|
||||
return AbstractFunction::MakeAbstractFunction(jv);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
|
@ -196,125 +175,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
return depends;
|
||||
}
|
||||
|
||||
bool CompareShape(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
|
||||
if (x_shape.size() != y_shape.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < x_shape.size(); ++i) {
|
||||
if (GetValue<int>(x_shape[i]) != GetValue<int>(y_shape[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
enum State {
|
||||
SAME,
|
||||
X_ONE,
|
||||
Y_ONE,
|
||||
};
|
||||
|
||||
void ComputeReduceIndex(const std::vector<int> &reverse_x, const std::vector<int> &reverse_y,
|
||||
std::vector<int> *grad_x_reduce_idx, std::vector<int> *grad_y_reduce_idy) {
|
||||
const size_t n = reverse_x.size();
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
State curr;
|
||||
const int32_t x_i = reverse_x[i];
|
||||
const int32_t y_i = reverse_y[i];
|
||||
const int reduce_idx = SizeToInt(n - 1 - i);
|
||||
if (x_i == y_i) {
|
||||
curr = SAME;
|
||||
} else if (x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
curr = X_ONE;
|
||||
} else if (y_i == 1) {
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
curr = Y_ONE;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "not compatible shape input for BroadcastGradientArgs";
|
||||
}
|
||||
if (curr == SAME && x_i == 1) {
|
||||
grad_x_reduce_idx->push_back(reduce_idx);
|
||||
grad_y_reduce_idy->push_back(reduce_idx);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::reverse(grad_x_reduce_idx->begin(), grad_x_reduce_idx->end());
|
||||
std::reverse(grad_y_reduce_idy->begin(), grad_y_reduce_idy->end());
|
||||
}
|
||||
|
||||
AbstractBasePtr BroadcastGradientArgsDiff(const std::vector<ValuePtr> &x_shape, const std::vector<ValuePtr> &y_shape) {
|
||||
std::vector<int> reverse_x;
|
||||
std::vector<int> reverse_y;
|
||||
|
||||
(void)std::transform(x_shape.rbegin(), x_shape.rend(), std::back_inserter(reverse_x),
|
||||
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
|
||||
(void)std::transform(y_shape.rbegin(), y_shape.rend(), std::back_inserter(reverse_y),
|
||||
[](const ValuePtr &v) { return v->cast<Int32ImmPtr>()->value(); });
|
||||
|
||||
if (reverse_x.size() > reverse_y.size()) {
|
||||
reverse_y.resize(reverse_x.size(), 1);
|
||||
} else {
|
||||
reverse_x.resize(reverse_y.size(), 1);
|
||||
}
|
||||
|
||||
std::vector<int> grad_x_reduce_idx;
|
||||
std::vector<int> grad_y_reduce_idy;
|
||||
ComputeReduceIndex(reverse_x, reverse_y, &grad_x_reduce_idx, &grad_y_reduce_idy);
|
||||
|
||||
AbstractBasePtrList abs_list_x;
|
||||
AbstractBasePtrList abs_list_y;
|
||||
(void)std::transform(grad_x_reduce_idx.begin(), grad_x_reduce_idx.end(), std::back_inserter(abs_list_x),
|
||||
[](int v) { return abstract::FromValue(v); });
|
||||
(void)std::transform(grad_y_reduce_idy.begin(), grad_y_reduce_idy.end(), std::back_inserter(abs_list_y),
|
||||
[](int v) { return abstract::FromValue(v); });
|
||||
auto x_reduce_idx = std::make_shared<AbstractTuple>(abs_list_x);
|
||||
auto y_reduce_idx = std::make_shared<AbstractTuple>(abs_list_y);
|
||||
AbstractBasePtrList elem_list;
|
||||
elem_list.push_back(x_reduce_idx);
|
||||
elem_list.push_back(y_reduce_idx);
|
||||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// this primitive get the index that need to reduce
|
||||
// input: x's shape and y's shape, inputs should be tuple
|
||||
// output: tuple of x and y 's reduce index, reduce index should be a tuple
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto arg_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
auto arg_y = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
ValueTuplePtr arg_x_value = arg_x->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_x_value);
|
||||
|
||||
ValueTuplePtr arg_y_value = arg_y->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(arg_y_value);
|
||||
|
||||
const std::vector<ValuePtr> x_shape = arg_x_value->value();
|
||||
const std::vector<ValuePtr> y_shape = arg_y_value->value();
|
||||
bool is_same_shape = CompareShape(x_shape, y_shape);
|
||||
// if it is the same shape , do not need reduce , return empty tuple
|
||||
if (is_same_shape) {
|
||||
AbstractBasePtrList empty_list;
|
||||
auto x_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
auto y_reduce_idx = std::make_shared<AbstractTuple>(empty_list);
|
||||
|
||||
AbstractBasePtrList elem_list;
|
||||
elem_list.push_back(x_reduce_idx);
|
||||
elem_list.push_back(y_reduce_idx);
|
||||
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
return BroadcastGradientArgsDiff(x_shape, y_shape);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: Two objects of a subclass of AbstractBase
|
|
@ -15,8 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "abstract/param_validator.h"
|
||||
#include "pipeline/jit/static_analysis/prim.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
|
||||
|
@ -34,38 +33,6 @@ AbstractBasePtr InferImplReturn(const AnalysisEnginePtr &, const PrimitivePtr &,
|
|||
return abs_base;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a pointer to an AbstractBase object
|
||||
if (args_spec_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires 1 parameter, while the input size is " << args_spec_list.size()
|
||||
<< ".";
|
||||
}
|
||||
AbstractBasePtr abs_base = args_spec_list[0];
|
||||
MS_EXCEPTION_IF_NULL(abs_base);
|
||||
TypePtr type = abs_base->BuildType();
|
||||
return std::make_shared<AbstractType>(type);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a pointer to an AbstractBase object and a pointer to a Type
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTypePtr abs_type = CheckArg<AbstractType>(op_name, args_spec_list, 1);
|
||||
|
||||
auto mode_v = abs_type->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(mode_v);
|
||||
if (!mode_v->isa<Type>()) {
|
||||
MS_LOG(EXCEPTION) << "Get the type from AbstractType value failed.";
|
||||
}
|
||||
|
||||
TypePtr mode_t = mode_v->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
bool v = IsSubtype(args_spec_list[0], mode_t);
|
||||
return std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(v), kBool);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors.
|
|
@ -0,0 +1,278 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "abstract/param_validator.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractTuple>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractList>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
size_t keys_size = keys->size();
|
||||
if (values->size() != keys_size) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
|
||||
}
|
||||
|
||||
std::vector<AbstractAttribute> key_value;
|
||||
AbstractScalarPtr key;
|
||||
AbstractBasePtrList key_list = keys->elements();
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(keyPtr);
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and a keyword.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_input = GetValue<std::string>(key_value);
|
||||
std::string key_actual = kwarg->get_key();
|
||||
if (key_actual != key_input) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
|
||||
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
|
||||
}
|
||||
return kwarg->get_arg();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three scalars whose value is an int32 number.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||
size_t args_size = args_spec_list.size();
|
||||
for (size_t index = 0; index < args_size; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
|
||||
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
|
||||
}
|
||||
if (args_spec_list[index]->isa<AbstractScalar>() &&
|
||||
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(TypeError) << "MakeSlice eval " << index
|
||||
<< " parameter is an AbstractScalar, but is not an int32 number.";
|
||||
}
|
||||
}
|
||||
// Slice: start, end, step
|
||||
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list and a scalar whose value is an int32 number.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
// when index_value is an AnyValue and args_spec_list[0] is a scalar, try to return the type of the first element
|
||||
// and continue
|
||||
if (dyn_cast<AbstractScalar>(queue->elements()[0]) != nullptr) {
|
||||
return std::make_shared<AbstractScalar>(queue->elements()[0]->BuildType());
|
||||
}
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
std::size_t nelems = queue->elements().size();
|
||||
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
|
||||
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
|
||||
}
|
||||
|
||||
std::size_t uidx_v = 0;
|
||||
if (idx_v >= 0) {
|
||||
uidx_v = IntToSize(idx_v);
|
||||
} else {
|
||||
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
|
||||
}
|
||||
return queue->elements()[uidx_v];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
if (idx_v < 0) {
|
||||
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||
<< ".";
|
||||
}
|
||||
|
||||
size_t uidx_v = IntToSize(idx_v);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
std::size_t nelems = elements.size();
|
||||
if (uidx_v >= nelems) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||
<< ".";
|
||||
}
|
||||
elements[uidx_v] = args_spec_list[2];
|
||||
return std::make_shared<T>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
auto key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
if (it == dict_elems.end()) {
|
||||
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
|
||||
if (it != dict_elems.end()) {
|
||||
int index = it - dict_elems.begin();
|
||||
dict_elems[IntToSize(index)] = new_ele;
|
||||
} else {
|
||||
dict_elems.push_back(new_ele);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a list and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||
(void)AbstractJoin(list->elements());
|
||||
return list;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||
static PrimitiveEvalImplMap prim_eval_implement_map = {
|
||||
// Statements
|
||||
{prim::kPrimReturn, {InferImplReturn, true}},
|
||||
{prim::kPrimDot, {InferImplDot, true}},
|
||||
{prim::kPrimSwitch, {InferImplSwitch, true}},
|
||||
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
|
||||
{prim::kPrimIs_, {InferImplIs_, true}},
|
||||
{prim::kPrimIsNot, {InferImplIsNot, true}},
|
||||
{prim::kPrimInDict, {InferImplInDict, true}},
|
||||
{prim::kPrimNotInDict, {InferImplNotInDict, true}},
|
||||
{prim::kPrimIsConsant, {InferImplIsConstant, true}},
|
||||
// Maths
|
||||
{prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
{prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}},
|
||||
// Array
|
||||
{prim::kPrimScalarToArray, {InferImplScalarToArray, true}},
|
||||
{prim::kPrimArrayToScalar, {InferImplArrayToScalar, true}},
|
||||
{prim::kPrimBroadcastShape, {InferImplBroadCastShape, true}},
|
||||
{prim::kPrimPack, {InferImplPack, true}},
|
||||
{prim::kPrimUnique, {InferImplUnique, true}},
|
||||
{prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}},
|
||||
// Structure
|
||||
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
|
||||
{prim::kPrimMakeList, {InferImplMakeList, true}},
|
||||
{prim::kPrimMakeDict, {InferImplMakeDict, true}},
|
||||
{prim::kPrimMakeSlice, {InferImplMakeSlice, true}},
|
||||
{prim::kPrimMakeKeywordArg, {InferImplMakeKwarg, true}},
|
||||
{prim::kPrimExtractKeywordArg, {InferImplExtractKwarg, true}},
|
||||
{prim::kPrimTupleGetItem, {InferImplTupleGetItem, true}},
|
||||
{prim::kPrimListGetItem, {InferImplListGetItem, true}},
|
||||
{prim::kPrimTupleSetItem, {InferImplTupleSetItem, true}},
|
||||
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
|
||||
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
|
||||
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
|
||||
{prim::kPrimListAppend, {InferImplListAppend, true}},
|
||||
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
|
||||
{prim::kPrimListLen, {InferImplListLen, true}},
|
||||
{prim::kPrimArrayLen, {InferImplArrayLen, true}},
|
||||
// NN
|
||||
{prim::kPrimPooling, {InferImplPooling, true}},
|
||||
{prim::kPrimPoolingGrad, {InferImplPoolingGrad, true}},
|
||||
{prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
|
||||
{prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
|
||||
{prim::kPrimReluGrad, {InferImplReluGrad, true}},
|
||||
{prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
|
||||
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, true}},
|
||||
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
|
||||
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
|
||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
||||
{prim::kPrimDropoutGenMask, {InferImplDropoutGenMask, true}},
|
||||
// Others
|
||||
{prim::kPrimIdentity, {InferImplIdentity, true}},
|
||||
// Set impl to null as it will use PartialEvaluator;
|
||||
{prim::kPrimPartial, {nullptr, true}},
|
||||
{prim::kPrimEnvGetItem, {InferImplEnvGetItem, true}},
|
||||
{prim::kPrimEnvSetItem, {InferImplEnvSetItem, true}},
|
||||
{prim::kPrimEnvAdd, {InferImplEnvAdd, true}},
|
||||
{prim::kPrimMakeRefKey, {InferImplMakeRefKey, true}},
|
||||
{prim::kPrimMakeRef, {InferImplMakeRef, true}},
|
||||
{prim::kPrimGetRefKey, {InferImplGetRefKey, true}},
|
||||
{prim::kPrimGetRefValue, {InferImplGetRefValue, true}},
|
||||
{prim::kPrimStateSetItem, {InferImplStateSetItem, true}},
|
||||
{prim::kPrimDepend, {InferImplDepend, true}},
|
||||
{prim::kPrimControlDepend, {InferImplControlDepend, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, true}},
|
||||
// SparseTensor
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
|
||||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
|
||||
// RowTensor
|
||||
{prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
|
||||
{prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
|
||||
{prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
|
||||
{prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
||||
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg) {
|
||||
auto &prim_eval_map = GetPrimitiveToEvalImplMap();
|
||||
prim_eval_map[primitive] = impl_reg;
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
#define MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
||||
#include <unordered_map>
|
||||
#include "ir/primitive.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
using StandardPrimitiveEvalImpl = AbstractBasePtr (*)(const abstract::AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &);
|
||||
struct StandardPrimitiveImplReg {
|
||||
StandardPrimitiveEvalImpl impl_; // Implement function of Primitive.
|
||||
bool in_white_list_; // true if this Primitive in white list, else false.
|
||||
};
|
||||
|
||||
using PrimitiveEvalImplMap =
|
||||
std::unordered_map<PrimitivePtr, StandardPrimitiveImplReg, PrimitiveHasher, PrimitiveEqual>;
|
||||
|
||||
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap();
|
||||
|
||||
void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
|
||||
|
||||
class RegisterStandardPrimitiveEvalHelper {
|
||||
public:
|
||||
RegisterStandardPrimitiveEvalHelper(const PrimitivePtr &primitive, const StandardPrimitiveEvalImpl &impl) {
|
||||
const StandardPrimitiveImplReg impl_reg{impl, true};
|
||||
RegisterStandardPrimitiveImpl(primitive, impl_reg);
|
||||
}
|
||||
~RegisterStandardPrimitiveEvalHelper() = default;
|
||||
};
|
||||
|
||||
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
|
||||
static auto helper_##name = RegisterStandardPrimitiveEvalHelper(primitive, impl)
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_
|
|
@ -246,6 +246,25 @@ inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_d
|
|||
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
|
||||
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
|
||||
|
||||
// Structures
|
||||
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
|
||||
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
|
||||
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
|
||||
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
|
||||
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
|
||||
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
|
||||
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
|
||||
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
|
||||
|
||||
// Other miscellaneous
|
||||
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
|
||||
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
|
||||
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
|
||||
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
|
||||
|
||||
// Other primitve not used by backend but used in core;
|
||||
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
|
|
|
@ -26,87 +26,19 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto conv_prim = primitive->cast<PrimConv2dPtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||
auto prim_name = conv_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
|
||||
w_shape[1], conv_prim->name());
|
||||
auto out_channel = conv_prim->GetOutputChannel();
|
||||
CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
|
||||
std::vector<int> temp_w;
|
||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
|
||||
conv_prim->name());
|
||||
|
||||
auto kernel_size_h = w_shape[2];
|
||||
auto kernel_size_w = w_shape[3];
|
||||
auto stride = conv_prim->GetStride();
|
||||
auto dilation = conv_prim->GetDilation();
|
||||
auto stride_h = stride[2];
|
||||
auto stride_w = stride[3];
|
||||
auto dilation_h = dilation[2];
|
||||
auto dilation_w = dilation[3];
|
||||
int h_out = -1;
|
||||
int w_out = -1;
|
||||
std::vector<int> pad_list(4, 0);
|
||||
auto pad_mode = conv_prim->GetPadMode();
|
||||
if (pad_mode == "valid") {
|
||||
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
||||
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
h_out = ceil(x_shape[2] / stride_h);
|
||||
w_out = ceil(x_shape[3] / stride_w);
|
||||
|
||||
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
||||
pad_list.emplace_back(floor(pad_needed_h / 2));
|
||||
pad_list.emplace_back(pad_needed_h / 2);
|
||||
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
||||
auto pad_left = floor(pad_needed_w / 2);
|
||||
pad_list.emplace_back(pad_left);
|
||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
||||
} else if (pad_mode == "pad") {
|
||||
std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list));
|
||||
auto pad_top = conv_prim->GetPad()[0];
|
||||
auto pad_bottom = conv_prim->GetPad()[1];
|
||||
auto pad_right = conv_prim->GetPad()[2];
|
||||
auto pad_left = conv_prim->GetPad()[3];
|
||||
|
||||
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
|
||||
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
|
||||
h_out = floor(h_out);
|
||||
w_out = floor(w_out);
|
||||
}
|
||||
conv_prim->SetPadList(pad_list);
|
||||
std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
|
||||
const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->GetTypeTrack());
|
||||
types.emplace("w", input_args[1]->GetTypeTrack());
|
||||
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (x_type == kNumberTypeInt8) {
|
||||
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
|
||||
}
|
||||
return std::make_shared<TensorType>(TypeIdToType(x_type));
|
||||
}
|
||||
constexpr auto kKernelSize = "kernel_size";
|
||||
constexpr auto kStride = "stride";
|
||||
constexpr auto kDilation = "dilation";
|
||||
constexpr auto kPadMode = "pad_mode";
|
||||
constexpr auto kPad = "pad";
|
||||
constexpr auto kMode = "mode";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr auto kOutputChannel = "output channel";
|
||||
constexpr auto kPadList = "pad_list";
|
||||
constexpr auto kConv2DName = "Conv2D";
|
||||
} // namespace
|
||||
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
|
||||
|
||||
void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode, const std::string &pad_mode,
|
||||
const std::vector<int> &pad, const std::vector<int> &stride, const std::vector<int> &dilation,
|
||||
int group) {
|
||||
|
@ -130,10 +62,47 @@ void Conv2d::Init(int out_channel, const std::vector<int> &kernel_size, int mode
|
|||
this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name));
|
||||
this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
|
||||
}
|
||||
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
std::vector<int> Conv2d::GetKernelSize() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> Conv2d::GetStride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> Conv2d::GetDilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::string Conv2d::GetPadMode() const {
|
||||
auto value_ptr = this->GetAttr(kPadMode);
|
||||
return GetValue<string>(value_ptr);
|
||||
}
|
||||
std::vector<int> Conv2d::GetPad() const {
|
||||
auto value_ptr = this->GetAttr(kPad);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
int Conv2d::GetMode() const {
|
||||
auto value_ptr = this->GetAttr(kMode);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
int Conv2d::GetGroup() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
int Conv2d::GetOutputChannel() const {
|
||||
auto value_ptr = this->GetAttr(kOutputChannel);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
void Conv2d::SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
|
||||
void Conv2d::SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void Conv2d::SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
void Conv2d::SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
|
||||
void Conv2d::SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
||||
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
|
||||
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
||||
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,79 +16,44 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H
|
||||
#define MINDSPORE_CORE_C_OPS_CONV2D_H
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
#define MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
namespace mindspore {
|
||||
class Conv2d : public PrimitiveC {
|
||||
public:
|
||||
Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
|
||||
Conv2d();
|
||||
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
||||
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
||||
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
||||
std::vector<int> GetKernelSize() const {
|
||||
auto value_ptr = this->GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> GetStride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> GetDilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::string GetPadMode() const {
|
||||
auto value_ptr = this->GetAttr(kPadMode);
|
||||
return GetValue<string>(value_ptr);
|
||||
}
|
||||
std::vector<int> GetPad() const {
|
||||
auto value_ptr = this->GetAttr(kPad);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
int GetMode() const {
|
||||
auto value_ptr = this->GetAttr(kMode);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
int GetGroup() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
int GetOutputChannel() const {
|
||||
auto value_ptr = this->GetAttr(kOutputChannel);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
void SetKernelSize(const std::vector<int> &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); }
|
||||
void SetStride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void SetDilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
|
||||
void SetPad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
||||
void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
|
||||
void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
||||
void SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
|
||||
|
||||
private:
|
||||
inline static const string kKernelSize = "kernel_size";
|
||||
inline static const string kStride = "stride";
|
||||
inline static const string kDilation = "dilation";
|
||||
inline static const string kPadMode = "pad_mode";
|
||||
inline static const string kPad = "pad";
|
||||
inline static const string kMode = "mode";
|
||||
inline static const string kGroup = "group";
|
||||
inline static const string kOutputChannel = "output channel";
|
||||
inline static const string kPadList = "pad_list";
|
||||
inline static const string kConv2DName = "Conv2D";
|
||||
std::vector<int> GetKernelSize() const;
|
||||
std::vector<int> GetStride() const;
|
||||
std::vector<int> GetDilation() const;
|
||||
std::string GetPadMode() const;
|
||||
std::vector<int> GetPad() const;
|
||||
int GetMode() const;
|
||||
int GetGroup() const;
|
||||
int GetOutputChannel() const;
|
||||
void SetKernelSize(const std::vector<int> &kernel_size);
|
||||
void SetStride(const std::vector<int> &stride);
|
||||
void SetDilation(const std::vector<int> &dilation);
|
||||
void SetPadMode(const std::string &pad_mode);
|
||||
void SetPad(const std::vector<int> &pad);
|
||||
void SetMode(int mode);
|
||||
void SetGroup(int group);
|
||||
void SetOutChannel(int output_channel);
|
||||
void SetPadList(const std::vector<int> &pad_list);
|
||||
};
|
||||
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
|
||||
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
|
||||
#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
|
||||
#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "ir/primitive.h"
|
||||
|
@ -25,7 +25,7 @@
|
|||
namespace mindspore {
|
||||
class PrimitiveC : public Primitive {
|
||||
public:
|
||||
explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; }
|
||||
explicit PrimitiveC(const std::string &name) : Primitive(name) {}
|
||||
|
||||
protected:
|
||||
void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) {
|
||||
|
@ -34,4 +34,4 @@ class PrimitiveC : public Primitive {
|
|||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H
|
||||
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_
|
||||
|
|
|
@ -61,6 +61,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/core/abstract/*.cc"
|
||||
"../../../mindspore/core/ir/*.cc"
|
||||
"../../../mindspore/core/utils/*.cc"
|
||||
"../../../mindspore/core/c_ops/*.cc"
|
||||
"../../../mindspore/ccsrc/common/*.cc"
|
||||
"../../../mindspore/ccsrc/utils/*.cc"
|
||||
"../../../mindspore/ccsrc/pipeline/jit/parse/*.cc"
|
||||
|
|
Loading…
Reference in New Issue