mod syn_code bug

This commit is contained in:
changzherui 2020-07-20 13:05:04 +08:00
parent 28f873e9ad
commit dbe4cc32ca
5 changed files with 4 additions and 982 deletions

View File

@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
const std::string inner_ops_module = INNER_OP_PATH; const std::string inner_ops_module = INNER_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module)); py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module)); py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
if (!py::hasattr(mod, common::SafeCStr(op_name))) { if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) { if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name; MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
} }
return inner_ops_module; return ops_module;
} }
return ops_module; return inner_ops_module;
} }
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {

View File

@ -1,707 +0,0 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/utils.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/ops.h"
#include "utils/convert_utils.h"
#include "ir/tensor_py.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace abstract {
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two scalars whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr value_x = scalar_x->BuildValue();
ValuePtr value_y = scalar_y->BuildValue();
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
<< ", param1: " << value_y->ToString();
}
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
return std::make_shared<AbstractScalar>(ret);
}
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two scalars whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr value_x = scalar_x->BuildValue();
ValuePtr value_y = scalar_y->BuildValue();
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
<< ", param1: " << value_y->ToString();
}
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
return std::make_shared<AbstractScalar>(ret);
}
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractTuple>(args_spec_list);
}
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractList>(args_spec_list);
}
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
size_t keys_size = keys->size();
if (values->size() != keys_size) {
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
}
std::vector<AbstractAttribute> key_value;
AbstractScalarPtr key;
AbstractBasePtrList key_list = keys->elements();
AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) {
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
ValuePtr keyPtr = key->BuildValue();
MS_EXCEPTION_IF_NULL(keyPtr);
if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
}
std::string key_string = GetValue<std::string>(keyPtr);
key_value.emplace_back(key_string, value_list[index]);
}
return std::make_shared<AbstractDictionary>(key_value);
}
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
ValuePtr keyPtr = key->BuildValue();
if (!keyPtr->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
}
std::string key_string = GetValue<std::string>(keyPtr);
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
}
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a string and a keyword.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
std::string key_input = GetValue<std::string>(key_value);
std::string key_actual = kwarg->get_key();
if (key_actual != key_input) {
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
}
return kwarg->get_arg();
}
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: three scalars whose value is an int32 number.
CheckArgsSize(primitive->name(), args_spec_list, 3);
size_t args_size = args_spec_list.size();
for (size_t index = 0; index < args_size; index++) {
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
}
if (args_spec_list[index]->isa<AbstractScalar>() &&
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
}
}
// Slice: start, end, step
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
}
// Eval the return type of make_record
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: at lease two objects of a subclass of AbstractBase.
if (args_spec_list.size() < 2) {
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
<< args_spec_list.size() << ".";
}
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
TypePtr type = args_spec_list[0]->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
if (type->type_id() != kMetaTypeTypeType) {
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
}
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
TypePtr type_ptr = value_track->cast<TypePtr>();
if (type_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
}
auto cls = dyn_cast<Class>(type_ptr);
MS_EXCEPTION_IF_NULL(cls);
ClassAttrVector attributes = cls->GetAttributes();
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
std::vector<AbstractAttribute> abs_attributes;
for (size_t i = 0; i < attributes.size(); i++) {
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
abs_attributes.push_back(elem);
}
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
}
template <typename T>
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list and a scalar whose value is an int32 number.
CheckArgsSize(op_name, args_spec_list, 2);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
<< index_value->ToString();
}
int idx_v = GetValue<int>(index_value);
std::size_t nelems = queue->elements().size();
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
}
std::size_t uidx_v = 0;
if (idx_v >= 0) {
uidx_v = IntToSize(idx_v);
} else {
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
}
return queue->elements()[uidx_v];
}
template <typename T>
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
CheckArgsSize(op_name, args_spec_list, 3);
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr index_value = index->BuildValue();
if (!index_value->isa<Int32Imm>()) {
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
<< index_value->ToString();
}
int idx_v = GetValue<int>(index_value);
if (idx_v < 0) {
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
<< ".";
}
size_t uidx_v = IntToSize(idx_v);
AbstractBasePtrList elements = queue->elements();
std::size_t nelems = elements.size();
if (uidx_v >= nelems) {
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
<< ".";
}
elements[uidx_v] = args_spec_list[2];
return std::make_shared<T>(elements);
}
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
auto key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
if (it == dict_elems.end()) {
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
}
return it->second;
}
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
ValuePtr key_value = key->BuildValue();
if (!key_value->isa<StringImm>()) {
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
}
std::string key_str = GetValue<std::string>(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
[key_str](AbstractAttribute &item) { return item.first == key_str; });
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
if (it != dict_elems.end()) {
int index = it - dict_elems.begin();
dict_elems[IntToSize(index)] = new_ele;
} else {
dict_elems.push_back(new_ele);
}
return std::make_shared<AbstractDictionary>(dict_elems);
}
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
(void)AbstractJoin(list->elements());
return list;
}
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple or list or dict.
CheckArgsSize(op_name, args_spec_list, 1);
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
}
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
}
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: fn, list1, list2, ...
MS_EXCEPTION_IF_NULL(engine);
if (args_spec_list.size() <= 1) {
MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
}
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
// check args from 1.
CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
AbstractBasePtrList subargs;
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
if (l_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
}
subargs.push_back(AbstractJoin(l_ptr->elements()));
}
EvalResultPtr engin_exc = engine->Execute(fn, subargs);
AbstractBasePtrList result;
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
result.push_back(engin_exc->abstract());
}
return std::make_shared<AbstractList>(result);
}
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a fn, a list and an object of a subclass of a AbstractBase.
MS_EXCEPTION_IF_NULL(engine);
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
AbstractBasePtr dflt = args_spec_list[2];
AbstractBasePtr list_type = AbstractJoin(lst->elements());
auto result1 = engine->Execute(fn, lst->elements());
auto result2 = engine->Execute(fn, {dflt, list_type});
MS_EXCEPTION_IF_NULL(result1->abstract());
MS_EXCEPTION_IF_NULL(result2->abstract());
return result1->abstract()->Join(result2->abstract());
}
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
auto tuple_elements = input->elements();
AbstractBasePtrList elem_list;
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
[](const AbstractBasePtr &elem) { return elem->Clone(); });
return std::make_shared<AbstractTuple>(elem_list);
}
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
size_t x_rank = x_shape->size();
std::set<int> axis_set;
auto axis_data = axis_value_ptr->value();
if (axis_data.empty()) {
int size = 1;
AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
return std::make_shared<AbstractTuple>(values);
}
for (auto &elem : axis_data) {
int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
(void)axis_set.insert(e_value);
}
auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
if (x_shp_data.size() < x_rank) {
MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
}
AbstractBasePtrList values;
for (size_t i = 0; i < x_rank; i++) {
if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
auto axis_v = MakeValue(1);
values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
} else {
int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
auto dim = MakeValue(dim_value);
values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
}
}
return std::make_shared<AbstractTuple>(values);
}
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: x_shape, axis
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
auto x_shp_value = shape_x->BuildValue();
if (x_shp_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << op_name
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
}
// Axis can be scalar, tuple or None
AbstractTuplePtr axis = nullptr;
if (args_spec_list[1]->isa<AbstractScalar>()) {
MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
axis = std::make_shared<AbstractTuple>(axis_list);
} else if (args_spec_list[1]->isa<AbstractTuple>()) {
MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
axis = args_spec_list[1]->cast<AbstractTuplePtr>();
} else {
MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
<< args_spec_list[1]->ToString();
}
auto axis_value = axis->BuildValue();
if (axis_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << op_name
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
}
auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(axis_value_ptr);
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
}
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 2);
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
auto div_shp_value = div_shp->BuildValue();
if (div_shp_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
}
auto shpx_value = shape_x->BuildValue();
if (shpx_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
}
if (div_shp->size() != shape_x->size()) {
MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
<< ", shapex: " << shape_x->size() << ".";
}
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
AbstractBasePtrList values;
for (size_t i = 0; i < div_shp_data.size(); i++) {
if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
}
int shapex_value = GetValue<int>(shpx_data[i]);
int div_value = GetValue<int>(div_shp_data[i]);
MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
if (div_value == 0) {
MS_LOG(EXCEPTION) << "error: division value should not be 0!";
}
if ((shapex_value % div_value) != 0) {
MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
}
int result = shapex_value / div_value;
auto result_v = MakeValue(result);
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
}
return std::make_shared<AbstractTuple>(values);
}
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
py::array data = py::array(data_tuple);
auto tensor = TensorPy::MakeTensor(data);
auto ret = tensor->ToAbstract();
ret->set_value(tensor);
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
return ret;
}
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tuple
// example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
auto shpx_value = shape_x->BuildValue();
if (shpx_value->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
}
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
int result = 1;
for (size_t i = 0; i < shpx_data.size(); i++) {
int value = GetValue<int>(shpx_data[i]);
result = IntMulWithOverflowCheck(result, value);
}
auto result_v = MakeValue(result);
MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
return std::make_shared<AbstractScalar>(result_v, result_v->type());
}
template <typename T>
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
// Inputs: two tuples or two lists.
CheckArgsSize(op_name, args_spec_list, 2);
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
ValuePtr x_value = input_x->BuildValue();
ValuePtr y_value = input_y->BuildValue();
return std::make_shared<AbstractScalar>(*x_value == *y_value);
}
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
}
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
}
struct SlideInfo {
int start;
int step;
int stop;
};
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
int arg1 = 0;
int arg2 = 0;
if (!args_spec_list.empty()) {
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto arg_value = args_spec_list[0]->BuildValue();
if (!arg_value->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
}
arg1 = GetValue<int>(arg_value);
}
if (args_spec_list.size() >= 2) {
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
auto arg_value = args_spec_list[1]->BuildValue();
if (!arg_value->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
}
arg2 = GetValue<int>(arg_value);
}
if (args_spec_list.size() == 3) {
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
auto arg_value = args_spec_list[2]->BuildValue();
if (!arg_value->isa<Int32Imm>()) {
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
}
slide->step = GetValue<int>(arg_value);
slide->start = arg1;
slide->stop = arg2;
}
if (args_spec_list.size() == 2) {
slide->start = arg1;
slide->stop = arg2;
}
if (args_spec_list.size() == 1) {
slide->stop = arg1;
}
}
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
}
if (args_spec_list.size() > 3) {
MS_LOG(EXCEPTION) << "Error args size of make range operational.";
}
SlideInfo slide = {0, 1, 0};
CalcSlidePara(args_spec_list, &slide);
if (slide.step == 0) {
MS_LOG(EXCEPTION) << "Error, step value is 0.";
}
AbstractBasePtrList args;
if (slide.start <= slide.stop) {
if (slide.step <= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
}
for (int i = slide.start; i < slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i));
}
} else {
if (slide.step >= 0) {
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
}
for (int i = slide.start; i > slide.stop; i += slide.step) {
args.push_back(abstract::FromValue(i));
}
}
return std::make_shared<AbstractTuple>(args);
}
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor
CheckArgsSize(primitive->name(), args_spec_list, 1);
return args_spec_list[0]->Clone();
}
} // namespace abstract
} // namespace mindspore

View File

@ -1,93 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
#include <memory>
#include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimMakeRef, X, Y, Z} -> Y
class MakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
return nullptr;
}
};
// {prim::kPrimGetRefValue, Parameter} -> Parameter
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
class GetRefParamEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
return nullptr;
}
};
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class GetMakeRefEliminater : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, y, z;
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
return nullptr;
}
};
// IsValueNode<RefKey>
class ReplaceRefkeyByParam : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
auto refkey = GetValueNode<RefKeyPtr>(node);
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
MS_EXCEPTION_IF_NULL(resource);
auto top_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(top_graph);
for (const auto &tnode : top_graph->parameters()) {
auto para = tnode->cast<ParameterPtr>();
if (para != nullptr && para->name() == refkey->tag()) {
return para;
}
}
return nullptr;
};
PatternNode<AnfNodePtr> x;
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
return nullptr;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_

View File

@ -1,175 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/graph_util/generate_graph.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
using mindspore::tensor::Tensor;
namespace mindspore {
namespace parallel {
std::string GetOpPythonPath(const OperatorName &op_name) {
// almost all ops are defined in two main paths
const std::string ops_module = OP_PATH;
const std::string inner_ops_module = INNER_OP_PATH;
py::module mod = py::module::import(common::SafeCStr(ops_module));
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
}
return ops_module;
}
return inner_ops_module;
}
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
std::string op_path = GetOpPythonPath(op_name);
py::module mod = py::module::import(common::SafeCStr(op_path));
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
return nullptr;
}
std::vector<py::object> arg_list;
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
[](const Attr &attr) { return ValuePtrToPyData(attr.second); });
py::object obj =
parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
ValuePtr op_instance = nullptr;
bool succ = parse::ConvertData(obj, &op_instance);
if (!succ) {
MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
return nullptr;
}
return op_instance;
}
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
auto value_node = NewValueNode(value_ptr);
MS_EXCEPTION_IF_NULL(value_node);
return value_node->cast<AnfNodePtr>();
}
static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
AnfNodePtr CreateInt32Tensor(int32_t value) {
auto it = int_tensor_map.find(value);
if (it != int_tensor_map.end()) {
return it->second;
}
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32);
ValuePtr value_ptr = MakeValue(tensor_ptr);
auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
int_tensor_map[value] = anf_node_ptr;
return anf_node_ptr;
}
AnfNodePtr CreatTypeInt(int32_t value) {
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
return ValuePtrToAnfNodePtr(value_ptr);
}
AnfNodePtr CreatInt32Imm(int32_t value) {
ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
return ValuePtrToAnfNodePtr(value_ptr);
}
std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (!prim) {
MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
}
std::string instance_name = prim->instance_name();
return HashInstanceName(instance_name);
}
std::string HashInstanceName(const std::string &name) {
auto using_hash_name = common::GetEnv(USING_HASH_NAME);
std::string instance_name;
if ((using_hash_name.empty()) || (using_hash_name == "on")) {
instance_name = HashName(name);
} else {
instance_name = name;
}
return instance_name;
}
Status GenerateGraph::Init(const CNodePtr &cnode) {
if (!cnode) {
MS_LOG(ERROR) << "Init:cnode is nullptr";
return FAILED;
}
cnode_ = cnode;
func_graph_ = cnode->func_graph();
if (!func_graph_) {
MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
return FAILED;
}
manager_ = func_graph_->manager();
if (!manager_) {
MS_LOG(ERROR) << "Init:manager_ is nullptr";
return FAILED;
}
scope_ = cnode_->scope();
if (!scope_) {
MS_LOG(ERROR) << "Init:scope_ is nullptr";
return FAILED;
}
virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
virtual_input_node_->set_scope(scope_);
instance_name_base_ = GetInstanceNameByCNode(cnode_);
name_idx_ = 0;
return SUCCESS;
}
AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_scope(scope_);
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
}
(void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0]
auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
return new_anf_node_ptr;
}
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
name_idx_++;
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
}
auto value_node = NewValueNode(pyop_instance);
return value_node->cast<AnfNodePtr>();
}
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
name_idx_++;
OperatorAttrs attrs;
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
if (pyop_instance == nullptr) {
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
}
auto value_node = NewValueNode(pyop_instance);
return value_node->cast<AnfNodePtr>();
}
} // namespace parallel
} // namespace mindspore

View File

@ -327,7 +327,6 @@ __all__ = [
"ApplyCenteredRMSProp", "ApplyCenteredRMSProp",
"SpaceToBatchND", "SpaceToBatchND",
"BatchToSpaceND", "BatchToSpaceND",
"ReverseSequence",
"SquareSumAll", "SquareSumAll",
"BitwiseAnd", "BitwiseAnd",
"BitwiseOr", "BitwiseOr",
@ -343,12 +342,10 @@ __all__ = [
"ApproximateEqual", "ApproximateEqual",
"InplaceUpdate", "InplaceUpdate",
"InTopK", "InTopK",
"CropAndResize",
"LRN", "LRN",
"Mod", "Mod",
"PopulationCount", "PopulationCount",
"ParallelConcat", "ParallelConcat",
"EmbeddingLookup",
"Push", "Push",
"Pull" "Pull"
] ]