forked from mindspore-Ecosystem/mindspore
!3229 modify syn_code bug
Merge pull request !3229 from changzherui/mod_syc_code
This commit is contained in:
commit
30e27049a1
|
@ -31,13 +31,13 @@ std::string GetOpPythonPath(const OperatorName &op_name) {
|
|||
const std::string inner_ops_module = INNER_OP_PATH;
|
||||
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
||||
}
|
||||
return inner_ops_module;
|
||||
return ops_module;
|
||||
}
|
||||
return ops_module;
|
||||
return inner_ops_module;
|
||||
}
|
||||
|
||||
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
||||
|
|
|
@ -1,707 +0,0 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pipeline/static_analysis/prim.h"
|
||||
#include "pipeline/static_analysis/utils.h"
|
||||
#include "pipeline/static_analysis/param_validator.h"
|
||||
#include "operator/ops.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "ir/tensor_py.h"
|
||||
|
||||
using mindspore::tensor::TensorPy;
|
||||
|
||||
namespace mindspore {
|
||||
namespace abstract {
|
||||
|
||||
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two scalars whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr value_x = scalar_x->BuildValue();
|
||||
ValuePtr value_y = scalar_y->BuildValue();
|
||||
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||
<< ", param1: " << value_y->ToString();
|
||||
}
|
||||
|
||||
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractTuple>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractList>(args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
|
||||
size_t keys_size = keys->size();
|
||||
if (values->size() != keys_size) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
|
||||
}
|
||||
|
||||
std::vector<AbstractAttribute> key_value;
|
||||
AbstractScalarPtr key;
|
||||
AbstractBasePtrList key_list = keys->elements();
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(keyPtr);
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
key_value.emplace_back(key_string, value_list[index]);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
|
||||
ValuePtr keyPtr = key->BuildValue();
|
||||
if (!keyPtr->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
|
||||
}
|
||||
std::string key_string = GetValue<std::string>(keyPtr);
|
||||
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a string and a keyword.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_input = GetValue<std::string>(key_value);
|
||||
std::string key_actual = kwarg->get_key();
|
||||
if (key_actual != key_input) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
|
||||
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
|
||||
}
|
||||
return kwarg->get_arg();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: three scalars whose value is an int32 number.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||
size_t args_size = args_spec_list.size();
|
||||
for (size_t index = 0; index < args_size; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
|
||||
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
|
||||
}
|
||||
if (args_spec_list[index]->isa<AbstractScalar>() &&
|
||||
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
|
||||
}
|
||||
}
|
||||
// Slice: start, end, step
|
||||
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||
}
|
||||
|
||||
// Eval the return type of make_record
|
||||
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||
if (args_spec_list.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||
<< args_spec_list.size() << ".";
|
||||
}
|
||||
|
||||
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
if (type->type_id() != kMetaTypeTypeType) {
|
||||
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||
}
|
||||
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||
if (type_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||
}
|
||||
|
||||
auto cls = dyn_cast<Class>(type_ptr);
|
||||
MS_EXCEPTION_IF_NULL(cls);
|
||||
ClassAttrVector attributes = cls->GetAttributes();
|
||||
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||
|
||||
std::vector<AbstractAttribute> abs_attributes;
|
||||
for (size_t i = 0; i < attributes.size(); i++) {
|
||||
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||
abs_attributes.push_back(elem);
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list and a scalar whose value is an int32 number.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
std::size_t nelems = queue->elements().size();
|
||||
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
|
||||
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
|
||||
}
|
||||
|
||||
std::size_t uidx_v = 0;
|
||||
if (idx_v >= 0) {
|
||||
uidx_v = IntToSize(idx_v);
|
||||
} else {
|
||||
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
|
||||
}
|
||||
return queue->elements()[uidx_v];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr index_value = index->BuildValue();
|
||||
if (!index_value->isa<Int32Imm>()) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||
<< index_value->ToString();
|
||||
}
|
||||
int idx_v = GetValue<int>(index_value);
|
||||
if (idx_v < 0) {
|
||||
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||
<< ".";
|
||||
}
|
||||
|
||||
size_t uidx_v = IntToSize(idx_v);
|
||||
AbstractBasePtrList elements = queue->elements();
|
||||
std::size_t nelems = elements.size();
|
||||
if (uidx_v >= nelems) {
|
||||
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||
<< ".";
|
||||
}
|
||||
elements[uidx_v] = args_spec_list[2];
|
||||
return std::make_shared<T>(elements);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
auto key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
if (it == dict_elems.end()) {
|
||||
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
std::string key_str = GetValue<std::string>(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[key_str](AbstractAttribute &item) { return item.first == key_str; });
|
||||
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
|
||||
if (it != dict_elems.end()) {
|
||||
int index = it - dict_elems.begin();
|
||||
dict_elems[IntToSize(index)] = new_ele;
|
||||
} else {
|
||||
dict_elems.push_back(new_ele);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a list and an object of a subclass of AbstractBase.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||
(void)AbstractJoin(list->elements());
|
||||
return list;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple or list or dict.
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: fn, list1, list2, ...
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
if (args_spec_list.size() <= 1) {
|
||||
MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
|
||||
}
|
||||
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
|
||||
// check args from 1.
|
||||
CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
|
||||
|
||||
AbstractBasePtrList subargs;
|
||||
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
||||
AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
|
||||
if (l_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
|
||||
}
|
||||
subargs.push_back(AbstractJoin(l_ptr->elements()));
|
||||
}
|
||||
EvalResultPtr engin_exc = engine->Execute(fn, subargs);
|
||||
AbstractBasePtrList result;
|
||||
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
||||
result.push_back(engin_exc->abstract());
|
||||
}
|
||||
return std::make_shared<AbstractList>(result);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a fn, a list and an object of a subclass of a AbstractBase.
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||
AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
|
||||
AbstractBasePtr dflt = args_spec_list[2];
|
||||
|
||||
AbstractBasePtr list_type = AbstractJoin(lst->elements());
|
||||
auto result1 = engine->Execute(fn, lst->elements());
|
||||
auto result2 = engine->Execute(fn, {dflt, list_type});
|
||||
MS_EXCEPTION_IF_NULL(result1->abstract());
|
||||
MS_EXCEPTION_IF_NULL(result2->abstract());
|
||||
return result1->abstract()->Join(result2->abstract());
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
|
||||
auto tuple_elements = input->elements();
|
||||
AbstractBasePtrList elem_list;
|
||||
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
|
||||
[](const AbstractBasePtr &elem) { return elem->Clone(); });
|
||||
return std::make_shared<AbstractTuple>(elem_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
|
||||
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||
size_t x_rank = x_shape->size();
|
||||
std::set<int> axis_set;
|
||||
auto axis_data = axis_value_ptr->value();
|
||||
if (axis_data.empty()) {
|
||||
int size = 1;
|
||||
AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
for (auto &elem : axis_data) {
|
||||
int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
|
||||
(void)axis_set.insert(e_value);
|
||||
}
|
||||
|
||||
auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
|
||||
if (x_shp_data.size() < x_rank) {
|
||||
MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
|
||||
}
|
||||
AbstractBasePtrList values;
|
||||
for (size_t i = 0; i < x_rank; i++) {
|
||||
if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
|
||||
auto axis_v = MakeValue(1);
|
||||
values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
|
||||
} else {
|
||||
int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
|
||||
auto dim = MakeValue(dim_value);
|
||||
values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: x_shape, axis
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
|
||||
auto x_shp_value = shape_x->BuildValue();
|
||||
if (x_shp_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << op_name
|
||||
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
// Axis can be scalar, tuple or None
|
||||
AbstractTuplePtr axis = nullptr;
|
||||
if (args_spec_list[1]->isa<AbstractScalar>()) {
|
||||
MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
|
||||
AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
|
||||
axis = std::make_shared<AbstractTuple>(axis_list);
|
||||
} else if (args_spec_list[1]->isa<AbstractTuple>()) {
|
||||
MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
|
||||
axis = args_spec_list[1]->cast<AbstractTuplePtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
|
||||
<< args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
auto axis_value = axis->BuildValue();
|
||||
if (axis_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << op_name
|
||||
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||
|
||||
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||
MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
|
||||
|
||||
auto div_shp_value = div_shp->BuildValue();
|
||||
if (div_shp_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
|
||||
}
|
||||
|
||||
auto shpx_value = shape_x->BuildValue();
|
||||
if (shpx_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
|
||||
}
|
||||
|
||||
if (div_shp->size() != shape_x->size()) {
|
||||
MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
|
||||
<< ", shapex: " << shape_x->size() << ".";
|
||||
}
|
||||
|
||||
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
|
||||
auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
|
||||
AbstractBasePtrList values;
|
||||
|
||||
for (size_t i = 0; i < div_shp_data.size(); i++) {
|
||||
if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
|
||||
}
|
||||
int shapex_value = GetValue<int>(shpx_data[i]);
|
||||
int div_value = GetValue<int>(div_shp_data[i]);
|
||||
MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
|
||||
if (div_value == 0) {
|
||||
MS_LOG(EXCEPTION) << "error: division value should not be 0!";
|
||||
}
|
||||
if ((shapex_value % div_value) != 0) {
|
||||
MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
|
||||
}
|
||||
|
||||
int result = shapex_value / div_value;
|
||||
auto result_v = MakeValue(result);
|
||||
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
|
||||
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
|
||||
py::array data = py::array(data_tuple);
|
||||
auto tensor = TensorPy::MakeTensor(data);
|
||||
auto ret = tensor->ToAbstract();
|
||||
ret->set_value(tensor);
|
||||
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tuple
|
||||
// example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||
|
||||
auto shpx_value = shape_x->BuildValue();
|
||||
if (shpx_value->isa<AnyValue>()) {
|
||||
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
|
||||
}
|
||||
|
||||
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
|
||||
|
||||
int result = 1;
|
||||
for (size_t i = 0; i < shpx_data.size(); i++) {
|
||||
int value = GetValue<int>(shpx_data[i]);
|
||||
result = IntMulWithOverflowCheck(result, value);
|
||||
}
|
||||
|
||||
auto result_v = MakeValue(result);
|
||||
MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
|
||||
return std::make_shared<AbstractScalar>(result_v, result_v->type());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tuples or two lists.
|
||||
CheckArgsSize(op_name, args_spec_list, 2);
|
||||
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||
|
||||
ValuePtr x_value = input_x->BuildValue();
|
||||
ValuePtr y_value = input_y->BuildValue();
|
||||
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
|
||||
}
|
||||
|
||||
struct SlideInfo {
|
||||
int start;
|
||||
int step;
|
||||
int stop;
|
||||
};
|
||||
|
||||
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||
int arg1 = 0;
|
||||
int arg2 = 0;
|
||||
if (!args_spec_list.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto arg_value = args_spec_list[0]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg1 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() >= 2) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||
auto arg_value = args_spec_list[1]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
arg2 = GetValue<int>(arg_value);
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 3) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||
auto arg_value = args_spec_list[2]->BuildValue();
|
||||
if (!arg_value->isa<Int32Imm>()) {
|
||||
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||
}
|
||||
slide->step = GetValue<int>(arg_value);
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 2) {
|
||||
slide->start = arg1;
|
||||
slide->stop = arg2;
|
||||
}
|
||||
|
||||
if (args_spec_list.size() == 1) {
|
||||
slide->stop = arg1;
|
||||
}
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
|
||||
}
|
||||
|
||||
if (args_spec_list.size() > 3) {
|
||||
MS_LOG(EXCEPTION) << "Error args size of make range operational.";
|
||||
}
|
||||
|
||||
SlideInfo slide = {0, 1, 0};
|
||||
CalcSlidePara(args_spec_list, &slide);
|
||||
|
||||
if (slide.step == 0) {
|
||||
MS_LOG(EXCEPTION) << "Error, step value is 0.";
|
||||
}
|
||||
|
||||
AbstractBasePtrList args;
|
||||
if (slide.start <= slide.stop) {
|
||||
if (slide.step <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||
}
|
||||
for (int i = slide.start; i < slide.stop; i += slide.step) {
|
||||
args.push_back(abstract::FromValue(i));
|
||||
}
|
||||
} else {
|
||||
if (slide.step >= 0) {
|
||||
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||
}
|
||||
for (int i = slide.start; i > slide.stop; i += slide.step) {
|
||||
args.push_back(abstract::FromValue(i));
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<AbstractTuple>(args);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Clone();
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
|
@ -1,93 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ir/pattern_matcher.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimMakeRef, X, Y, Z} -> Y
|
||||
class MakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||
class GetRefParamEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||
class GetMakeRefEliminater : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x, y, z;
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// IsValueNode<RefKey>
|
||||
class ReplaceRefkeyByParam : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
|
||||
auto refkey = GetValueNode<RefKeyPtr>(node);
|
||||
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
||||
MS_EXCEPTION_IF_NULL(resource);
|
||||
|
||||
auto top_graph = resource->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(top_graph);
|
||||
|
||||
for (const auto &tnode : top_graph->parameters()) {
|
||||
auto para = tnode->cast<ParameterPtr>();
|
||||
if (para != nullptr && para->name() == refkey->tag()) {
|
||||
return para;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
PatternNode<AnfNodePtr> x;
|
||||
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
|
@ -1,175 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "parallel/graph_util/generate_graph.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using mindspore::tensor::Tensor;
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
std::string GetOpPythonPath(const OperatorName &op_name) {
|
||||
// almost all ops are defined in two main paths
|
||||
const std::string ops_module = OP_PATH;
|
||||
const std::string inner_ops_module = INNER_OP_PATH;
|
||||
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
||||
}
|
||||
return ops_module;
|
||||
}
|
||||
return inner_ops_module;
|
||||
}
|
||||
|
||||
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
||||
std::string op_path = GetOpPythonPath(op_name);
|
||||
py::module mod = py::module::import(common::SafeCStr(op_path));
|
||||
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<py::object> arg_list;
|
||||
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
|
||||
[](const Attr &attr) { return ValuePtrToPyData(attr.second); });
|
||||
py::object obj =
|
||||
parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
|
||||
ValuePtr op_instance = nullptr;
|
||||
bool succ = parse::ConvertData(obj, &op_instance);
|
||||
if (!succ) {
|
||||
MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
|
||||
return nullptr;
|
||||
}
|
||||
return op_instance;
|
||||
}
|
||||
|
||||
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
|
||||
auto value_node = NewValueNode(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
return value_node->cast<AnfNodePtr>();
|
||||
}
|
||||
|
||||
static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
|
||||
AnfNodePtr CreateInt32Tensor(int32_t value) {
|
||||
auto it = int_tensor_map.find(value);
|
||||
if (it != int_tensor_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32);
|
||||
ValuePtr value_ptr = MakeValue(tensor_ptr);
|
||||
auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
|
||||
int_tensor_map[value] = anf_node_ptr;
|
||||
return anf_node_ptr;
|
||||
}
|
||||
|
||||
AnfNodePtr CreatTypeInt(int32_t value) {
|
||||
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
|
||||
return ValuePtrToAnfNodePtr(value_ptr);
|
||||
}
|
||||
|
||||
AnfNodePtr CreatInt32Imm(int32_t value) {
|
||||
ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
|
||||
return ValuePtrToAnfNodePtr(value_ptr);
|
||||
}
|
||||
|
||||
std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
|
||||
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (!prim) {
|
||||
MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
|
||||
}
|
||||
std::string instance_name = prim->instance_name();
|
||||
return HashInstanceName(instance_name);
|
||||
}
|
||||
|
||||
std::string HashInstanceName(const std::string &name) {
|
||||
auto using_hash_name = common::GetEnv(USING_HASH_NAME);
|
||||
std::string instance_name;
|
||||
if ((using_hash_name.empty()) || (using_hash_name == "on")) {
|
||||
instance_name = HashName(name);
|
||||
} else {
|
||||
instance_name = name;
|
||||
}
|
||||
return instance_name;
|
||||
}
|
||||
|
||||
Status GenerateGraph::Init(const CNodePtr &cnode) {
|
||||
if (!cnode) {
|
||||
MS_LOG(ERROR) << "Init:cnode is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
cnode_ = cnode;
|
||||
func_graph_ = cnode->func_graph();
|
||||
if (!func_graph_) {
|
||||
MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
manager_ = func_graph_->manager();
|
||||
if (!manager_) {
|
||||
MS_LOG(ERROR) << "Init:manager_ is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
scope_ = cnode_->scope();
|
||||
if (!scope_) {
|
||||
MS_LOG(ERROR) << "Init:scope_ is nullptr";
|
||||
return FAILED;
|
||||
}
|
||||
virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
|
||||
virtual_input_node_->set_scope(scope_);
|
||||
instance_name_base_ = GetInstanceNameByCNode(cnode_);
|
||||
name_idx_ = 0;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
|
||||
CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
cnode->set_scope(scope_);
|
||||
if (inputs.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
|
||||
}
|
||||
(void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0]
|
||||
auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
|
||||
return new_anf_node_ptr;
|
||||
}
|
||||
|
||||
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
|
||||
name_idx_++;
|
||||
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
|
||||
if (pyop_instance == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
|
||||
}
|
||||
auto value_node = NewValueNode(pyop_instance);
|
||||
return value_node->cast<AnfNodePtr>();
|
||||
}
|
||||
|
||||
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
|
||||
name_idx_++;
|
||||
OperatorAttrs attrs;
|
||||
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
|
||||
if (pyop_instance == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
|
||||
}
|
||||
auto value_node = NewValueNode(pyop_instance);
|
||||
return value_node->cast<AnfNodePtr>();
|
||||
}
|
||||
} // namespace parallel
|
||||
} // namespace mindspore
|
|
@ -330,7 +330,6 @@ __all__ = [
|
|||
"ApplyCenteredRMSProp",
|
||||
"SpaceToBatchND",
|
||||
"BatchToSpaceND",
|
||||
"ReverseSequence",
|
||||
"SquareSumAll",
|
||||
"BitwiseAnd",
|
||||
"BitwiseOr",
|
||||
|
@ -346,12 +345,10 @@ __all__ = [
|
|||
"ApproximateEqual",
|
||||
"InplaceUpdate",
|
||||
"InTopK",
|
||||
"CropAndResize",
|
||||
"LRN",
|
||||
"Mod",
|
||||
"PopulationCount",
|
||||
"ParallelConcat",
|
||||
"EmbeddingLookup",
|
||||
"Push",
|
||||
"Pull"
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue