forked from mindspore-Ecosystem/mindspore
!8686 [lite] adjust minir model attr changes
From: @xu_anyue Reviewed-by: @hangangqiang Signed-off-by:
This commit is contained in:
commit
c253c3a5c6
|
@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
|
||||
attr->axis = {1};
|
||||
} else {
|
||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
||||
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
|
|
@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
|
|||
MS_LOG(WARNING) << "get axis failed";
|
||||
attr->axis = {0};
|
||||
} else {
|
||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
||||
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
|
|
@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
|
|
@ -139,21 +139,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|||
} else {
|
||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
|
@ -175,7 +175,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
|||
|
||||
int channel_mutiplier = 1;
|
||||
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
||||
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
||||
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||
}
|
||||
attr->channelMultiplier = channel_mutiplier;
|
||||
|
||||
|
@ -212,25 +212,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
|
|||
} else {
|
||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
||||
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
|
@ -270,7 +270,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
int group = GetValue<int>(groupAttr);
|
||||
int group = CastToInt(groupAttr, false).front();
|
||||
if (group > 1) {
|
||||
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
||||
} else {
|
||||
|
|
|
@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
|
|||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->group = GetValue<int>(prim.GetAttr("group"));
|
||||
attr->group = CastToInt(prim.GetAttr("group"), false).front();
|
||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||
if (format == "NCHW") {
|
||||
attr->format = schema::Format_NCHW;
|
||||
|
@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
|
|||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[0];
|
||||
attr->strideW = stride[1];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
||||
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
|
|
|
@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->group = GetValue<int>(prim.GetAttr("group"));
|
||||
attr->group = CastToInt(prim.GetAttr("group"), false).front();
|
||||
if (attr->group > 1) {
|
||||
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
|
||||
}
|
||||
|
@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[0];
|
||||
attr->strideW = stride[1];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
||||
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid") {
|
||||
|
|
|
@ -132,21 +132,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
|
|||
} else {
|
||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[0];
|
||||
attr->strideW = stride[1];
|
||||
|
||||
|
@ -168,7 +168,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
|
|||
|
||||
int channel_mutiplier = 1;
|
||||
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
||||
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
||||
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||
}
|
||||
attr->channelMultiplier = channel_mutiplier;
|
||||
|
||||
|
@ -195,25 +195,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi
|
|||
} else {
|
||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[0];
|
||||
attr->strideW = stride[1];
|
||||
|
||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
||||
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||
if (pad_mode == "valid" || pad_mode == "VALID") {
|
||||
|
@ -248,7 +248,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
|
|||
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
int group = GetValue<int>(prim.GetAttr("group"));
|
||||
int group = CastToInt(prim.GetAttr("group"), false).front();
|
||||
if (group == 1) {
|
||||
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
|
||||
} else if (group > 1) {
|
||||
|
|
|
@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
} else {
|
||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pads"));
|
||||
auto pad_list = CastToInt(prim.GetAttr("pads"), true);
|
||||
attr->padUp = pad_list[0];
|
||||
attr->padDown = pad_list[1];
|
||||
attr->padLeft = pad_list[2];
|
||||
attr->padRight = pad_list[3];
|
||||
|
||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
||||
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||
attr->dilateH = dilation[0];
|
||||
attr->dilateW = dilation[1];
|
||||
|
||||
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||
attr->kernelH = kernel_size[0];
|
||||
attr->kernelW = kernel_size[1];
|
||||
} else {
|
||||
auto kernel_size = GetValue<int>(prim.GetAttr("kernel_size"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front();
|
||||
attr->kernelH = kernel_size;
|
||||
attr->kernelW = kernel_size;
|
||||
}
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
||||
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
|
||||
|
@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
|||
} else {
|
||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||
}
|
||||
auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
||||
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||
attr->channelMultiplier = channel_multiplier;
|
||||
|
||||
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
||||
|
|
|
@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
// use axis instead of dim
|
||||
if (inputs[1]->isa<ValueNode>()) {
|
||||
auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
|
||||
int axis = GetValue<int>(axis_tensor->value());
|
||||
int axis = CastToInt(axis_tensor->value(), false).front();
|
||||
attr->dim = axis;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "input axis is not value node.";
|
||||
|
|
|
@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
}
|
||||
if (inputs[2]->isa<ValueNode>()) {
|
||||
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
|
||||
int axis = GetValue<int>(axis_tensor->value());
|
||||
int axis = CastToInt(axis_tensor->value(), false).front();
|
||||
gather_attr->axis = axis;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "input axis is not value node.";
|
||||
|
|
|
@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
}
|
||||
attr->axis = -1;
|
||||
if (prim.GetAttr("axis") != nullptr) {
|
||||
attr->axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
attr->axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
|
|
@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
|
||||
attr->windowH = kernel_size[2];
|
||||
attr->windowW = kernel_size[3];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
|
||||
auto stride = CastToInt(prim.GetAttr("strides"), true);
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
this->primitive_->value.value = attr;
|
||||
|
|
|
@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
|||
attr->padMode = schema::PadMode_NOTSET;
|
||||
}
|
||||
|
||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
|
||||
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
|
||||
attr->windowH = kernel_size[2];
|
||||
attr->windowW = kernel_size[3];
|
||||
|
||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
|
||||
auto stride = CastToInt(prim.GetAttr("strides"), true);
|
||||
attr->strideH = stride[2];
|
||||
attr->strideW = stride[3];
|
||||
this->primitive_->value.value = attr;
|
||||
|
|
|
@ -180,6 +180,35 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
std::vector<int> CastToInt(const ValuePtr value, bool is_vector) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(WARNING) << "valueptr is nullptr.";
|
||||
return {};
|
||||
}
|
||||
std::vector<int> cur_value;
|
||||
if (is_vector) {
|
||||
if (!utils::isa<ValueSequeuePtr>(value)) {
|
||||
MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar.";
|
||||
return {};
|
||||
}
|
||||
if (value->cast<ValueSequeuePtr>()->value().front()->type()->type_name() == "Int64Imm") {
|
||||
auto origin_value = GetValue<std::vector<int64_t>>(value);
|
||||
for (size_t index = 0; index < origin_value.size(); ++index) {
|
||||
cur_value.push_back(static_cast<int>(origin_value[index]));
|
||||
}
|
||||
} else {
|
||||
cur_value = GetValue<std::vector<int>>(value);
|
||||
}
|
||||
} else {
|
||||
if (value->type_name() == "Int64Imm") {
|
||||
cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
|
||||
} else {
|
||||
cur_value.push_back(GetValue<int>(value));
|
||||
}
|
||||
}
|
||||
return cur_value;
|
||||
}
|
||||
|
||||
void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
||||
const float qmin = 0;
|
||||
const float qmax = 255;
|
||||
|
|
|
@ -52,6 +52,8 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU",
|
|||
{"Sigmoid", schema::ActivationType_SIGMOID},
|
||||
{"HSwish", schema::ActivationType_HSWISH},
|
||||
{"HSigmoid", schema::ActivationType_HSIGMOID}};
|
||||
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
|
||||
|
||||
class PrimitiveC : public mindspore::Primitive {
|
||||
public:
|
||||
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
|
||||
|
|
|
@ -87,7 +87,7 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
attr->axes.emplace_back(elem->value());
|
||||
}
|
||||
} else {
|
||||
int axes_item = GetValue<int>(value);
|
||||
int axes_item = CastToInt(value, false).front();
|
||||
attr->axes.push_back(axes_item);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
attr->shape.emplace_back(static_cast<int>(elem->value()));
|
||||
}
|
||||
} else {
|
||||
int dim = GetValue<int>(val);
|
||||
int dim = CastToInt(val, false).front();
|
||||
attr->shape = {dim};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
|||
MS_LOG(ERROR) << "wrong resize type";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size"));
|
||||
std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true);
|
||||
attr->newHeight = targetSize[0];
|
||||
attr->newWidth = targetSize[1];
|
||||
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));
|
||||
|
|
|
@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
||||
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||
attr->axis = prim_axis;
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
|
|
|
@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
MS_LOG(INFO) << "Squeeze's attr xis is set to default";
|
||||
attr->axis = {0};
|
||||
} else {
|
||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
||||
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
|
|
|
@ -73,11 +73,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
|
|||
MS_LOG(ERROR) << "new StridedSlice failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask"));
|
||||
attr->endMask = GetValue<int>(prim.GetAttr("end_mask"));
|
||||
attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask"));
|
||||
attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask"));
|
||||
attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask"));
|
||||
attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front();
|
||||
attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front();
|
||||
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front();
|
||||
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front();
|
||||
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front();
|
||||
auto inputNodeFirst = inputs[kAnfPopulaterOne];
|
||||
std::vector<int> beginVec;
|
||||
GetAttrDataFromInput(inputNodeFirst, &beginVec);
|
||||
|
|
|
@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
|
|||
MS_LOG(INFO) << "Tile's attr dims is set to default";
|
||||
attr->dims = {1};
|
||||
} else {
|
||||
attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims"));
|
||||
attr->dims = CastToInt(prim.GetAttr("dims"), true);
|
||||
}
|
||||
if (inputs.size() == kAnfPopulaterTwo) {
|
||||
auto inputNode = inputs[kAnfPopulaterOne];
|
||||
|
@ -75,7 +75,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
|
|||
attr->multiples.emplace_back(elem->value());
|
||||
}
|
||||
} else {
|
||||
int multiple = GetValue<int>(value);
|
||||
int multiple = CastToInt(value, false).front();
|
||||
attr->multiples = {multiple};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN
|
|||
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
|
||||
if (inputs[2]->isa<ValueNode>()) {
|
||||
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
|
||||
attr->numSegments = GetValue<int>(value);
|
||||
attr->numSegments = CastToInt(value, false).front();
|
||||
this->primitive_->value.value = attr.release();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -314,7 +314,9 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto input_index_key =
|
||||
get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue<int>(value_node->value()));
|
||||
get_item_input_cnode->fullname_with_scope() + "_o:" +
|
||||
std::to_string(value_node->value()->type_name() == "Int64Imm" ? GetValue<int64_t>(value_node->value())
|
||||
: GetValue<int>(value_node->value()));
|
||||
auto iter = node_id_map_.find(input_index_key);
|
||||
if (iter == node_id_map_.end()) {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
|
|
Loading…
Reference in New Issue