Add ConvertShapePtrToShapeMap
This commit is contained in:
parent
0a870440e0
commit
3a71a6218a
|
@ -396,6 +396,20 @@ std::vector<int64_t> CheckAndConvertUtils::ConvertShapePtrToShape(const std::str
|
||||||
return shape_element->shape();
|
return shape_element->shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
|
||||||
|
MS_EXCEPTION_IF_NULL(shape);
|
||||||
|
if (!shape->isa<abstract::Shape>()) {
|
||||||
|
return std::map<std::string, std::vector<int64_t>>();
|
||||||
|
}
|
||||||
|
auto shape_element = shape->cast<abstract::ShapePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_element);
|
||||||
|
ShapeMap shape_map;
|
||||||
|
shape_map[kShape] = shape_element->shape();
|
||||||
|
shape_map[kMinShape] = shape_element->min_shape();
|
||||||
|
shape_map[kMaxShape] = shape_element->max_shape();
|
||||||
|
return shape_map;
|
||||||
|
}
|
||||||
|
|
||||||
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
||||||
const string &value_name, int64_t value, const string &prim_name,
|
const string &value_name, int64_t value, const string &prim_name,
|
||||||
ExceptionType exception_type) {
|
ExceptionType exception_type) {
|
||||||
|
|
|
@ -30,6 +30,10 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair;
|
typedef std::pair<std::map<std::string, int64_t>, std::map<int64_t, std::string>> AttrConverterPair;
|
||||||
|
typedef std::map<std::string, std::vector<int64_t>> ShapeMap;
|
||||||
|
constexpr auto kShape = "shape";
|
||||||
|
constexpr auto kMinShape = "min_shape";
|
||||||
|
constexpr auto kMaxShape = "max_shape";
|
||||||
|
|
||||||
enum CompareEnum : int64_t {
|
enum CompareEnum : int64_t {
|
||||||
kEqual = 1, // ==
|
kEqual = 1, // ==
|
||||||
|
@ -234,6 +238,9 @@ class CheckAndConvertUtils {
|
||||||
|
|
||||||
static std::vector<int64_t> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
|
static std::vector<int64_t> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
|
||||||
const std::string &prim_name);
|
const std::string &prim_name);
|
||||||
|
|
||||||
|
static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape);
|
||||||
|
|
||||||
static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
||||||
const std::string &value_name, int64_t value, const std::string &prim_name = "",
|
const std::string &value_name, int64_t value, const std::string &prim_name = "",
|
||||||
ExceptionType exception_type = ValueError);
|
ExceptionType exception_type = ValueError);
|
||||||
|
|
Loading…
Reference in New Issue