!28243 [ME] Fix magic numbers; Add tensor type check with np.str_.

Merge pull request !28243 from Margaret_wangrui/tensor
This commit is contained in:
i-robot 2021-12-28 01:37:41 +00:00 committed by Gitee
commit 348f8d8fd7
5 changed files with 21 additions and 24 deletions

View File

@ -108,21 +108,23 @@ AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &
return func_graph->NewCNodeInOrder(inputs);
}
std::vector<std::string> HyperMap::GetHyperMapInputIndex(size_t num) {
std::pair<std::string, std::string> HyperMap::GetHyperMapInputIndex(size_t num) {
std::string error_index;
std::string next_index;
if (num == 1) {
const size_t first_index = 1;
const size_t second_index = 2;
if (num == first_index) {
// The first element in HyperMap is func_graph
error_index = "first";
next_index = "second";
} else if (num == 2) {
} else if (num == second_index) {
error_index = "second";
next_index = "third";
} else {
error_index = std::to_string(num) + "th";
next_index = std::to_string(num + 1) + "th";
}
return {error_index, next_index};
return std::pair<std::string, std::string>(error_index, next_index);
}
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
@ -137,9 +139,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
for (auto &item : arg_map) {
num++;
auto lhs = std::static_pointer_cast<List>(item.second);
std::vector<std::string> indexes = GetHyperMapInputIndex(num);
std::string error_index = indexes[0];
std::string next_index = indexes[1];
auto [error_index, next_index] = GetHyperMapInputIndex(num);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a List, but got "
<< item.second->ToString() << ".";
@ -199,9 +199,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
for (auto &item : arg_map) {
num++;
auto lhs = std::static_pointer_cast<Tuple>(item.second);
std::vector<std::string> indexes = GetHyperMapInputIndex(num);
std::string error_index = indexes[0];
std::string next_index = indexes[1];
auto [error_index, next_index] = GetHyperMapInputIndex(num);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a Tuple, but got "
<< item.second->ToString() << ".";
@ -317,6 +315,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
int64_t idx = 0;
std::string str_index = "first";
const size_t diff_index = 2;
for (auto &item : arg_map) {
// The first element in HyperMap is func_graph
if (idx == 0) {
@ -324,7 +323,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
} else if (idx == 1) {
str_index = "third";
} else {
str_index = std::to_string(idx + 2) + "th";
str_index = std::to_string(idx + diff_index) + "th";
}
++idx;
oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";

View File

@ -81,7 +81,7 @@ class HyperMap : public MetaFuncGraph {
const ArgsPairList &arg_map);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
std::vector<std::string> GetHyperMapInputIndex(size_t num);
std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num);
MultitypeFuncGraphPtr fn_leaf_;
bool reverse_;

View File

@ -65,21 +65,23 @@ FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
return ptrGraph;
}
std::vector<std::string> Map::GetMapInputIndex(size_t num) {
std::pair<std::string, std::string> Map::GetMapInputIndex(size_t num) {
std::string error_index;
std::string next_index;
if (num == 1) {
const size_t first_index = 1;
const size_t second_index = 2;
if (num == first_index) {
// The first element in Map is func_graph
error_index = "first";
next_index = "second";
} else if (num == 2) {
} else if (num == second_index) {
error_index = "second";
next_index = "third";
} else {
error_index = std::to_string(num) + "th";
next_index = std::to_string(num + 1) + "th";
}
return {error_index, next_index};
return std::pair<std::string, std::string>(error_index, next_index);
}
AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
@ -94,9 +96,7 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
for (auto &item : arg_pairs) {
num++;
auto lhs = std::dynamic_pointer_cast<List>(item.second);
std::vector<std::string> indexes = GetMapInputIndex(num);
std::string error_index = indexes[0];
std::string next_index = indexes[1];
auto [error_index, next_index] = GetMapInputIndex(num);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a List, but got "
<< item.second->ToString() << ".";
@ -157,9 +157,7 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
for (auto &item : arg_pairs) {
num++;
auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
std::vector<std::string> indexes = GetMapInputIndex(num);
std::string error_index = indexes[0];
std::string next_index = indexes[1];
auto [error_index, next_index] = GetMapInputIndex(num);
if (lhs == nullptr) {
MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a Tuple, but got "
<< item.second->ToString() << ".";

View File

@ -78,7 +78,7 @@ class Map : public MetaFuncGraph {
AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
const ArgsPairList &arg_pairs);
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
std::vector<std::string> GetMapInputIndex(size_t num);
std::pair<std::string, std::string> GetMapInputIndex(size_t num);
void Init() {
if (fn_leaf_ != nullptr) {
name_ = "map[" + fn_leaf_->name() + "]";

View File

@ -118,7 +118,7 @@ class Tensor(Tensor_):
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
if init is None:
validator.check_value_type('input_data', input_data,
(Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor')
(Tensor_, np.ndarray, np.str_, list, tuple, float, int, bool, complex), 'Tensor')
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \