!25245 Move TypeId2String and DtypeToTypeId to dtype_extends.cc
Merge pull request !25245 from DeshiChen/1019_typeid2string
Commit: b27991a2af
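This commit replaces the kernel-local helpers kernel::DtypeToTypeId and
kernel::TypeId2String with the core helpers StringToTypeId and TypeIdToString
from ir/dtype, as the hunks below show; call sites pass to_lower=true so the
kernel JSON keeps its lowercase dtype names. A minimal before/after sketch of
the migration ("float32" is just an example value, not taken from the diff):

    // Before this commit (helpers local to backend/kernel_compiler):
    //   TypeId id = kernel::DtypeToTypeId("float32");
    //   std::string s = kernel::TypeId2String(id, /*unknown_as_default=*/true);
    // After this commit (core helpers declared in ir/dtype/type.h):
    TypeId id = StringToTypeId("float32");
    std::string s = TypeIdToString(id, /*to_lower=*/true);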
@@ -187,7 +187,7 @@ class CNodeDecoder {
         inputs.push_back(nodes_map_[name]);
       }
       input_formats_.push_back(input_desc[kJsonKeyFormat]);
-      input_types_.push_back(DtypeToTypeId(input_desc[kJsonKeyDataType]));
+      input_types_.push_back(StringToTypeId(input_desc[kJsonKeyDataType]));
       input_shapes_.push_back(input_desc[kJsonKeyShape]);
     }
     // new cnode.
@@ -205,7 +205,7 @@ class CNodeDecoder {
       // single output.
       nlohmann::json output_desc = output_descs[0];
       output_formats_.push_back(output_desc[kJsonKeyFormat]);
-      output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
+      output_types_.push_back(StringToTypeId(output_desc[kJsonKeyDataType]));
       output_shapes_.push_back(output_desc[kJsonKeyShape]);
       nodes_map_[output_desc[kJsonKeyTensorName]] = cnode_;
     } else {
@@ -213,7 +213,7 @@ class CNodeDecoder {
       for (size_t j = 0; j < output_descs.size(); ++j) {
         nlohmann::json output_desc = output_descs[j];
         output_formats_.push_back(output_desc[kJsonKeyFormat]);
-        output_types_.push_back(DtypeToTypeId(output_desc[kJsonKeyDataType]));
+        output_types_.push_back(StringToTypeId(output_desc[kJsonKeyDataType]));
         output_shapes_.push_back(output_desc[kJsonKeyShape]);
         auto get_item =
           func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_, NewValueNode(SizeToLong(j))});
@@ -282,7 +282,7 @@ class CNodeDecoder {
   }

   tensor::TensorPtr DecodeScalar(const nlohmann::json &scalar_json) const {
-    auto type_id = DtypeToTypeId(scalar_json[kJsonKeyDataType]);
+    auto type_id = StringToTypeId(scalar_json[kJsonKeyDataType]);
     switch (type_id) {
       case kNumberTypeFloat16:
         return std::make_shared<tensor::Tensor>(static_cast<float>(scalar_json[kJsonKeyValue]), kFloat16);
@@ -310,7 +310,7 @@ class CNodeDecoder {
     auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
     // layout info.
     builder->SetOutputsFormat(std::vector<std::string>{value_json[kJsonKeyFormat]});
-    builder->SetOutputsDeviceType(std::vector<TypeId>{DtypeToTypeId(value_json[kJsonKeyDataType])});
+    builder->SetOutputsDeviceType(std::vector<TypeId>{StringToTypeId(value_json[kJsonKeyDataType])});
     AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), value_node.get());
     func_graph->AddValueNode(value_node);
     MS_LOG(DEBUG) << "decode value node success, " << value_node->DebugString(2);
@@ -340,7 +340,7 @@ ParameterPtr AkgKernelJsonDecoder::DecodeParameter(const nlohmann::json &paramet
   std::string name = parameter_json[kJsonKeyTensorName];
   new_parameter->set_name(name);
   std::string format = parameter_json[kJsonKeyFormat];
-  TypeId dtype = DtypeToTypeId(parameter_json[kJsonKeyDataType]);
+  TypeId dtype = StringToTypeId(parameter_json[kJsonKeyDataType]);
   ShapeVector shape = AbstractShapeCreator::GetFakeAbstractShape(parameter_json[kJsonKeyShape], format);
   auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(dtype), shape);
   new_parameter->set_abstract(abstract);
@@ -193,7 +193,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   std::vector<nlohmann::json> input_list;
   for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
     auto type_id = this->GetInputDataType(anf_node, real_input_index);
-    std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
+    std::string dtype = TypeIdToString(type_id, true);
     if (dtype.empty()) {
       MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] input [" << real_input_index
                     << "] data type is null. ";
@@ -234,7 +234,7 @@ bool AkgKernelJsonGenerator::CreateOutputDescJson(const AnfNodePtr &anf_node, co
   for (size_t i = 0; i < output_tensor_num; i++) {
     nlohmann::json output_json;
     auto type_id = this->GetOutputDataType(anf_node, i);
-    std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
+    std::string dtype = TypeIdToString(type_id, true);
     if (dtype.empty()) {
       MS_LOG(ERROR) << "Op [" << anf_node->fullname_with_scope() << "] output [" << i << "] data type is null. ";
       return false;
@@ -271,7 +271,7 @@ void AkgKernelJsonGenerator::GetAttrJson(const AnfNodePtr &anf_node, const std::
     (*attr_json)[kJsonKeyValue] = get_int_value(attr_value);
   } else if (type == "str") {
     if (attr_value->isa<Type>()) {
-      (*attr_json)[kJsonKeyValue] = TypeId2String(attr_value->cast<TypePtr>()->type_id());
+      (*attr_json)[kJsonKeyValue] = TypeIdToString(attr_value->cast<TypePtr>()->type_id(), true);
     } else {
       (*attr_json)[kJsonKeyValue] = GetValue<std::string>(attr_value);
     }
@@ -723,7 +723,7 @@ nlohmann::json AkgKernelJsonGenerator::CreateInputsJson(const std::vector<AnfNod
   for (size_t i = 0; i < input_index.size(); ++i) {
     auto tmp_input = input_index[i];
     auto type_id = this->GetInputDataType(tmp_input.first, tmp_input.second.first);
-    std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
+    std::string dtype = TypeIdToString(type_id, true);
     nlohmann::json input_desc_json;
     input_desc_json[kJsonKeyTensorName] =
       GetTensorName(node_json_map.at(tmp_input.first), kJsonKeyInputDesc, tmp_input.second);
@@ -823,7 +823,7 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
     }
     if (!found) {
       auto type_id = this->GetOutputDataType(tmp_output.first, tmp_output.second);
-      std::string dtype = TypeId2String(type_id, dump_option_.is_before_select_kernel);
+      std::string dtype = TypeIdToString(type_id, true);
       output_desc_json[kJsonKeyTensorName] =
         GetTensorName(node_json_map.at(tmp_output.first), kJsonKeyOutputDesc, std::make_pair(0, tmp_output.second));
       output_desc_json[kJsonKeyDataType] = dtype;
@@ -36,41 +36,6 @@ namespace mindspore {
 namespace kernel {
 constexpr char kAxis[] = "axis";
 constexpr char kTypeInt32[] = "Int32";
-const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
-                                                              {"float16", TypeId::kNumberTypeFloat16},
-                                                              {"float32", TypeId::kNumberTypeFloat32},
-                                                              {"float64", TypeId::kNumberTypeFloat64},
-                                                              {"int", TypeId::kNumberTypeInt},
-                                                              {"int8", TypeId::kNumberTypeInt8},
-                                                              {"int16", TypeId::kNumberTypeInt16},
-                                                              {"int32", TypeId::kNumberTypeInt32},
-                                                              {"int64", TypeId::kNumberTypeInt64},
-                                                              {"uint", TypeId::kNumberTypeUInt},
-                                                              {"uint8", TypeId::kNumberTypeUInt8},
-                                                              {"uint16", TypeId::kNumberTypeUInt16},
-                                                              {"uint32", TypeId::kNumberTypeUInt32},
-                                                              {"uint64", TypeId::kNumberTypeUInt64},
-                                                              {"bool", TypeId::kNumberTypeBool},
-                                                              {"complex64", TypeId::kNumberTypeComplex64},
-                                                              {"complex128", TypeId::kNumberTypeComplex128}};
-
-const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
-                                                       {TypeId::kNumberTypeFloat16, "float16"},
-                                                       {TypeId::kNumberTypeFloat, "float"},
-                                                       {TypeId::kNumberTypeFloat64, "float64"},
-                                                       {TypeId::kNumberTypeInt, "int"},
-                                                       {TypeId::kNumberTypeInt8, "int8"},
-                                                       {TypeId::kNumberTypeInt16, "int16"},
-                                                       {TypeId::kNumberTypeInt32, "int32"},
-                                                       {TypeId::kNumberTypeInt64, "int64"},
-                                                       {TypeId::kNumberTypeUInt, "uint"},
-                                                       {TypeId::kNumberTypeUInt8, "uint8"},
-                                                       {TypeId::kNumberTypeUInt16, "uint16"},
-                                                       {TypeId::kNumberTypeUInt32, "uint32"},
-                                                       {TypeId::kNumberTypeUInt64, "uint64"},
-                                                       {TypeId::kNumberTypeBool, "bool"},
-                                                       {TypeId::kNumberTypeComplex64, "complex64"},
-                                                       {TypeId::kNumberTypeComplex128, "complex128"}};
-
 const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
   {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
@@ -276,24 +241,10 @@ KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &pro
 }

 TypeId DtypeToTypeId(const std::string &dtypes) {
-  auto iter = type_id_maps.find(dtypes);
-  if (iter != type_id_maps.end()) {
-    return iter->second;
-  } else {
-    MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
+  if (dtypes == "float") {
+    return TypeId::kNumberTypeFloat32;
   }
-}
-
-std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
-  auto iter = type_id_str_map.find(type_id);
-  if (iter == type_id_str_map.end()) {
-    if (!unknown_as_default) {
-      MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
-    }
-    MS_LOG(INFO) << "Using default dtype: float32";
-    return "float32";
-  }
-  return iter->second;
+  return StringToTypeId(dtypes);
 }

 std::string Dtype2ShortType(const std::string &dtype) {
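Note on the common_utils.cc hunk above: DtypeToTypeId no longer throws on
unknown names via the deleted type_id_maps table; it keeps only the legacy
special case mapping "float" to kNumberTypeFloat32 (the removed
{"float", TypeId::kNumberTypeFloat32} entry) and delegates everything else to
the core StringToTypeId. A sketch of the expected behavior, assuming the
MindSpore kernel headers (hypothetical call sites, not part of the diff):

    // "float" keeps its historical kernel meaning of float32:
    TypeId a = kernel::DtypeToTypeId("float");  // kNumberTypeFloat32
    // Other names go through the core string parser:
    TypeId b = kernel::DtypeToTypeId("int32");  // kNumberTypeInt32 via StringToTypeId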
@@ -79,7 +79,6 @@ KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &pro
 KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor);
 TypeId DtypeToTypeId(const std::string &dtypes);
 std::string Dtype2ShortType(const std::string &dtypes);
-std::string TypeId2String(TypeId type_id, bool unknown_as_default = false);
 size_t GetDtypeNbyte(const std::string &dtypes);
 bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i);
 bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
@@ -60,7 +60,7 @@ void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
                    [&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
     shape_list_.push_back(in_shape_tmp);
     ndims_.push_back(SizeToInt(in_shape_tmp.size()));
-    type_list_.push_back(TypeId2String(input_type_list[i]));
+    type_list_.push_back(TypeIdToString(input_type_list[i], true));
   }

   num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -77,7 +77,7 @@ void CustomAOTCpuKernel::InitKernel(const CNodePtr &kernel_node) {
                    [&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
     shape_list_.push_back(out_shape_tmp);
     ndims_.push_back(SizeToInt(out_shape_tmp.size()));
-    type_list_.push_back(TypeId2String(output_type_list[i]));
+    type_list_.push_back(TypeIdToString(output_type_list[i], true));
   }

   std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
@@ -121,7 +121,7 @@ class CustomAOTGpuKernel : public GpuKernel {
                    [&in_shape_tmp](size_t c) { in_shape_tmp.push_back(SizeToLong(c)); });
     shape_list_.push_back(in_shape_tmp);
     ndims_.push_back(SizeToInt(in_shape_tmp.size()));
-    type_list_.push_back(TypeId2String(input_type_list[i]));
+    type_list_.push_back(TypeIdToString(input_type_list[i], true));
   }

   num_output_ = AnfAlgo::GetOutputTensorNum(kernel_node);
@@ -140,7 +140,7 @@ class CustomAOTGpuKernel : public GpuKernel {
                    [&out_shape_tmp](size_t c) { out_shape_tmp.push_back(SizeToLong(c)); });
     shape_list_.push_back(out_shape_tmp);
     ndims_.push_back(SizeToInt(out_shape_tmp.size()));
-    type_list_.push_back(TypeId2String(output_type_list[i]));
+    type_list_.push_back(TypeIdToString(output_type_list[i], true));
   }

   std::transform(std::begin(shape_list_), std::end(shape_list_), std::back_inserter(shapes_),
@@ -62,12 +62,12 @@ std::string GpuKernelFactory::SupportedTypeList(const std::string &kernel_name)
     std::string type_list = "in[";
     auto attr = (iter->second)[attr_index].first;
     for (size_t input_index = 0; input_index < attr.GetInputSize(); ++input_index) {
-      type_list = type_list + TypeId2String(attr.GetInputAttr(input_index).first) +
+      type_list = type_list + TypeIdToString(attr.GetInputAttr(input_index).first) +
                   ((input_index == (attr.GetInputSize() - 1)) ? "" : " ");
     }
     type_list = type_list + "], out[";
     for (size_t input_index = 0; input_index < attr.GetOutputSize(); ++input_index) {
-      type_list = type_list + TypeId2String(attr.GetOutputAttr(input_index).first) +
+      type_list = type_list + TypeIdToString(attr.GetOutputAttr(input_index).first) +
                   ((input_index == (attr.GetOutputSize() - 1)) ? "" : " ");
     }
     type_lists = type_lists + type_list + "]; ";
@@ -37,7 +37,7 @@ void Node::DumpTensor(std::ostringstream &os) const {
     os << shape[i];
     if (i + 1 < shape.size()) os << ",";
   }
-  os << "]{" << TypeIdToType(type)->ToString() << "x" << format << "}";
+  os << "]{" << TypeIdToString(type) << "x" << format << "}";
 }

 void Node::AddInput(const NodePtr &new_input) {
@@ -33,7 +33,6 @@
 #include "mindspore/core/ir/tensor.h"
 #include "mindspore/core/utils/shape_utils.h"
 #include "utils/utils.h"
-#include "backend/kernel_compiler/common_utils.h"

 namespace mindspore::graphkernel::inner {
 enum class NType {
@@ -306,7 +306,7 @@ TypeId CastOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
   if (dst_type->isa<Type>()) {
     return dst_type->cast<TypePtr>()->type_id();
   }
-  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
+  return StringToTypeId(GetValue<std::string>(dst_type));
 }

 void SelectOp::CheckType(const NodePtrList &inputs, const DAttrs &) {
@@ -439,7 +439,7 @@ TypeId Conv2dOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
   if (dst_type->isa<Type>()) {
     return dst_type->cast<TypePtr>()->type_id();
   }
-  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
+  return StringToTypeId(GetValue<std::string>(dst_type));
 }

 DShape TransposeOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
@@ -492,13 +492,12 @@ DShape MatMulOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
 }

 TypeId MatMulOp::InferType(const NodePtrList &inputs, const DAttrs &attrs) {
-  CHECK_ATTR(attrs, "dst_type");
+  if (attrs.find("dst_type") == attrs.end()) return inputs[0]->type;
   auto dst_type = attrs.find("dst_type")->second;
   if (dst_type->isa<Type>()) {
     return dst_type->cast<TypePtr>()->type_id();
   }
-  return kernel::DtypeToTypeId(GetValue<std::string>(dst_type));
+  return StringToTypeId(GetValue<std::string>(dst_type));
 }

 DShape PadAkgOp::InferShape(const NodePtrList &inputs, const DAttrs &attrs) {
@@ -24,7 +24,6 @@
 #include <functional>

 #include "backend/optimizer/graph_kernel/model/node.h"
-#include "backend/kernel_compiler/common_utils.h"
 #include "ir/dtype/type.h"

 namespace mindspore::graphkernel::inner {
@@ -89,12 +89,12 @@ std::string SupportedTypeList(const CNodePtr &kernel_node) {
     auto supported_akg_type_out = kernel_info_list[i]->GetAllOutputDeviceTypes();
     std::string supported_akg_type_list = "in[";
     for (auto type : supported_akg_type) {
-      supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
+      supported_akg_type_list = supported_akg_type_list + TypeIdToString(type);
     }
     supported_type_lists = supported_type_lists + supported_akg_type_list + "], out[";
     supported_akg_type_list.clear();
     for (auto type : supported_akg_type_out) {
-      supported_akg_type_list = supported_akg_type_list + mindspore::kernel::TypeId2String(type);
+      supported_akg_type_list = supported_akg_type_list + TypeIdToString(type);
     }
     supported_type_lists = supported_type_lists + supported_akg_type_list + "]; ";
   }
@@ -381,10 +381,10 @@ void PrintUnsupportedTypeException(const CNodePtr &kernel_node, const std::vecto
   auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
   std::string build_type = "in [";
   std::for_each(std::begin(inputs_type), std::end(inputs_type),
-                [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+                [&build_type](auto i) { build_type += TypeIdToString(i) + " "; });
   build_type += "] out [";
   std::for_each(std::begin(outputs_type), std::end(outputs_type),
-                [&build_type](auto i) { build_type += mindspore::kernel::TypeId2String(i) + " "; });
+                [&build_type](auto i) { build_type += TypeIdToString(i) + " "; });
   build_type += "]";
   auto supported_type_lists = SupportedTypeList(kernel_node);
   MS_EXCEPTION(TypeError) << "Select GPU kernel op[" << kernel_name
@@ -48,6 +48,14 @@ namespace mindspore {
 /// \return The shared_ptr of Type.
 MS_CORE_API TypePtr TypeIdToType(TypeId id);

+/// \brief Get the type string according to a TypeId.
+///
+/// \param[in] id Define a TypeId.
+/// \param[in] to_lower Whether to convert the string to lowercase.
+///
+/// \return The string of Type.
+MS_CORE_API std::string TypeIdToString(TypeId id, bool to_lower = false);
+
 /// \brief String defines a type of string.
 class MS_CORE_API String : public Object {
  public:
@@ -372,6 +380,13 @@ TypePtr Clone(const T &t) {
 /// \return The shared_ptr of type.
 MS_CORE_API TypePtr StringToType(const std::string &type_name);

+/// \brief Get the TypeId of Type according to a string of type name.
+///
+/// \param[in] type_name Define a string of type name.
+///
+/// \return The TypeId of type.
+MS_CORE_API TypeId StringToTypeId(const std::string &type_name);
+
 /// \brief Given a type x and a base type, judge whether x is the base type or is a subclass of the base type.
 ///
 /// \param[in] x Define the type to be judged.
@@ -82,6 +82,24 @@ TypePtr TypeIdToType(TypeId id) {
   return it->second;
 }

+std::string TypeIdToString(TypeId id, bool to_lower) {
+  switch (id) {
+    case TypeId::kNumberTypeFloat:
+      return "float";
+    case TypeId::kNumberTypeInt:
+      return "int";
+    case TypeId::kNumberTypeUInt:
+      return "uint";
+    default:
+      break;
+  }
+  auto type = TypeIdToType(id)->ToString();
+  if (to_lower) {
+    std::transform(type.begin(), type.end(), type.begin(), [](auto c) { return static_cast<char>(std::tolower(c)); });
+  }
+  return type;
+}
+
 namespace {
 template <typename T>
 TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) {
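Given the switch in the new TypeIdToString above, the generic number kinds keep
their short names and everything else falls back to TypeIdToType(id)->ToString(),
optionally lowercased. Expected values, assuming TypeIdToType(kNumberTypeFloat32)
prints as "Float32" (a sketch, not part of the diff):

    TypeIdToString(TypeId::kNumberTypeFloat);          // "float"
    TypeIdToString(TypeId::kNumberTypeFloat32);        // "Float32"
    TypeIdToString(TypeId::kNumberTypeFloat32, true);  // "float32"

The lowercase conversion mirrors a standard idiom; a self-contained, runnable
illustration in plain C++ (safe here because type names are ASCII):

    #include <algorithm>
    #include <cctype>
    #include <iostream>
    #include <string>

    int main() {
      std::string type = "Float32";
      // Same pattern as the diff: per-character std::tolower, cast back to char.
      std::transform(type.begin(), type.end(), type.begin(),
                     [](auto c) { return static_cast<char>(std::tolower(c)); });
      std::cout << type << '\n';  // prints "float32"
      return 0;
    }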
@@ -279,6 +297,7 @@ TypePtr GetTypeByFullString(const std::string &type_name) {
   {"EnvType", std::make_shared<EnvType>()},
   {"Number", std::make_shared<Number>()},
   {"Bool", std::make_shared<Bool>()},
+  {"bool", std::make_shared<Bool>()},
   {"Slice", std::make_shared<Slice>()},
   {"Dictionary", std::make_shared<Dictionary>()},
   {"String", std::make_shared<String>()},
@@ -298,10 +317,15 @@ TypePtr GetTypeByStringStarts(const std::string &type_name) {
       return r.compare(0, cmp_len, l, 0, cmp_len) < 0;
     }
   };
-  static std::map<std::string, std::function<TypePtr(const std::string &type_name)>, name_cmp> type_map = {
+  static std::map<std::string, std::function<TypePtr(const std::string &)>, name_cmp> type_map = {
     {"Int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "Int"); }},
+    {"int", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Int>(type_name, "int"); }},
     {"UInt", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "UInt"); }},
+    {"uint", [](const std::string &type_name) -> TypePtr { return StringToNumberType<UInt>(type_name, "uint"); }},
     {"Float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "Float"); }},
+    {"float", [](const std::string &type_name) -> TypePtr { return StringToNumberType<Float>(type_name, "float"); }},
+    {"Complex", [](const std::string &tname) -> TypePtr { return StringToNumberType<Complex>(tname, "Complex"); }},
+    {"complex", [](const std::string &tname) -> TypePtr { return StringToNumberType<Complex>(tname, "complex"); }},
     {"Tensor", [](const std::string &type_name) -> TypePtr { return TensorStrToType(type_name); }},
     {"Undetermined", [](const std::string &type_name) -> TypePtr { return UndeterminedStrToType(type_name); }},
     {"RowTensor", [](const std::string &type_name) -> TypePtr { return RowTensorStrToType(type_name); }},
@@ -330,6 +354,8 @@ TypePtr StringToType(const std::string &type_name) {
   return type;
 }

+TypeId StringToTypeId(const std::string &type_name) { return StringToType(type_name)->type_id(); }
+
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
   if (x == nullptr || base_type == nullptr) {
     MS_LOG(ERROR) << "Type is nullptr.";
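With the one-line StringToTypeId definition added above, plus the new lowercase
type_map entries ("int", "uint", "float", "complex") and the "bool" full-string
entry, lowercase dtype names now round-trip through the core parser. A usage
sketch, assuming the ir/dtype headers (hypothetical values, not from the diff):

    TypeId f32 = StringToTypeId("float32");  // kNumberTypeFloat32
    TypeId bl = StringToTypeId("bool");      // kNumberTypeBool
    std::string name = TypeIdToString(f32, /*to_lower=*/true);  // "float32"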