forked from mindspore-Ecosystem/mindspore
!28243 [ME] Fix magic numbers; Add tensor type check with np.str_.
Merge pull request !28243 from Margaret_wangrui/tensor
This commit is contained in:
commit
348f8d8fd7
|
@ -108,21 +108,23 @@ AnfNodePtr HyperMap::FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
return func_graph->NewCNodeInOrder(inputs);
|
||||
}
|
||||
|
||||
std::vector<std::string> HyperMap::GetHyperMapInputIndex(size_t num) {
|
||||
std::pair<std::string, std::string> HyperMap::GetHyperMapInputIndex(size_t num) {
|
||||
std::string error_index;
|
||||
std::string next_index;
|
||||
if (num == 1) {
|
||||
const size_t first_index = 1;
|
||||
const size_t second_index = 2;
|
||||
if (num == first_index) {
|
||||
// The first element in HyperMap is func_graph
|
||||
error_index = "first";
|
||||
next_index = "second";
|
||||
} else if (num == 2) {
|
||||
} else if (num == second_index) {
|
||||
error_index = "second";
|
||||
next_index = "third";
|
||||
} else {
|
||||
error_index = std::to_string(num) + "th";
|
||||
next_index = std::to_string(num + 1) + "th";
|
||||
}
|
||||
return {error_index, next_index};
|
||||
return std::pair<std::string, std::string>(error_index, next_index);
|
||||
}
|
||||
|
||||
AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
|
||||
|
@ -137,9 +139,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<List> &type, const FuncGraph
|
|||
for (auto &item : arg_map) {
|
||||
num++;
|
||||
auto lhs = std::static_pointer_cast<List>(item.second);
|
||||
std::vector<std::string> indexes = GetHyperMapInputIndex(num);
|
||||
std::string error_index = indexes[0];
|
||||
std::string next_index = indexes[1];
|
||||
auto [error_index, next_index] = GetHyperMapInputIndex(num);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a List, but got "
|
||||
<< item.second->ToString() << ".";
|
||||
|
@ -199,9 +199,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
for (auto &item : arg_map) {
|
||||
num++;
|
||||
auto lhs = std::static_pointer_cast<Tuple>(item.second);
|
||||
std::vector<std::string> indexes = GetHyperMapInputIndex(num);
|
||||
std::string error_index = indexes[0];
|
||||
std::string next_index = indexes[1];
|
||||
auto [error_index, next_index] = GetHyperMapInputIndex(num);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The " << error_index << " element in HyperMap has wrong type, expected a Tuple, but got "
|
||||
<< item.second->ToString() << ".";
|
||||
|
@ -317,6 +315,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
|
|||
<< trace::GetDebugInfo(func_graph->debug_info()) << "\n";
|
||||
int64_t idx = 0;
|
||||
std::string str_index = "first";
|
||||
const size_t diff_index = 2;
|
||||
for (auto &item : arg_map) {
|
||||
// The first element in HyperMap is func_graph
|
||||
if (idx == 0) {
|
||||
|
@ -324,7 +323,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_a
|
|||
} else if (idx == 1) {
|
||||
str_index = "third";
|
||||
} else {
|
||||
str_index = std::to_string(idx + 2) + "th";
|
||||
str_index = std::to_string(idx + diff_index) + "th";
|
||||
}
|
||||
++idx;
|
||||
oss << "The type of the " << str_index << " argument in HyperMap is " << item.second->ToString() << ".\n";
|
||||
|
|
|
@ -81,7 +81,7 @@ class HyperMap : public MetaFuncGraph {
|
|||
const ArgsPairList &arg_map);
|
||||
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
|
||||
ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
|
||||
std::vector<std::string> GetHyperMapInputIndex(size_t num);
|
||||
std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num);
|
||||
|
||||
MultitypeFuncGraphPtr fn_leaf_;
|
||||
bool reverse_;
|
||||
|
|
|
@ -65,21 +65,23 @@ FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) {
|
|||
return ptrGraph;
|
||||
}
|
||||
|
||||
std::vector<std::string> Map::GetMapInputIndex(size_t num) {
|
||||
std::pair<std::string, std::string> Map::GetMapInputIndex(size_t num) {
|
||||
std::string error_index;
|
||||
std::string next_index;
|
||||
if (num == 1) {
|
||||
const size_t first_index = 1;
|
||||
const size_t second_index = 2;
|
||||
if (num == first_index) {
|
||||
// The first element in Map is func_graph
|
||||
error_index = "first";
|
||||
next_index = "second";
|
||||
} else if (num == 2) {
|
||||
} else if (num == second_index) {
|
||||
error_index = "second";
|
||||
next_index = "third";
|
||||
} else {
|
||||
error_index = std::to_string(num) + "th";
|
||||
next_index = std::to_string(num + 1) + "th";
|
||||
}
|
||||
return {error_index, next_index};
|
||||
return std::pair<std::string, std::string>(error_index, next_index);
|
||||
}
|
||||
|
||||
AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph,
|
||||
|
@ -94,9 +96,7 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr<List> &type, const FuncGraphP
|
|||
for (auto &item : arg_pairs) {
|
||||
num++;
|
||||
auto lhs = std::dynamic_pointer_cast<List>(item.second);
|
||||
std::vector<std::string> indexes = GetMapInputIndex(num);
|
||||
std::string error_index = indexes[0];
|
||||
std::string next_index = indexes[1];
|
||||
auto [error_index, next_index] = GetMapInputIndex(num);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a List, but got "
|
||||
<< item.second->ToString() << ".";
|
||||
|
@ -157,9 +157,7 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGrap
|
|||
for (auto &item : arg_pairs) {
|
||||
num++;
|
||||
auto lhs = std::dynamic_pointer_cast<Tuple>(item.second);
|
||||
std::vector<std::string> indexes = GetMapInputIndex(num);
|
||||
std::string error_index = indexes[0];
|
||||
std::string next_index = indexes[1];
|
||||
auto [error_index, next_index] = GetMapInputIndex(num);
|
||||
if (lhs == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "The " << error_index << " element in Map has wrong type, expected a Tuple, but got "
|
||||
<< item.second->ToString() << ".";
|
||||
|
|
|
@ -78,7 +78,7 @@ class Map : public MetaFuncGraph {
|
|||
AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
|
||||
const ArgsPairList &arg_pairs);
|
||||
AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
|
||||
std::vector<std::string> GetMapInputIndex(size_t num);
|
||||
std::pair<std::string, std::string> GetMapInputIndex(size_t num);
|
||||
void Init() {
|
||||
if (fn_leaf_ != nullptr) {
|
||||
name_ = "map[" + fn_leaf_->name() + "]";
|
||||
|
|
|
@ -118,7 +118,7 @@ class Tensor(Tensor_):
|
|||
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method.
|
||||
if init is None:
|
||||
validator.check_value_type('input_data', input_data,
|
||||
(Tensor_, np.ndarray, list, tuple, float, int, bool, complex), 'Tensor')
|
||||
(Tensor_, np.ndarray, np.str_, list, tuple, float, int, bool, complex), 'Tensor')
|
||||
valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
|
||||
np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128)
|
||||
if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \
|
||||
|
|
Loading…
Reference in New Issue