forked from mindspore-Ecosystem/mindspore
!8686 [lite] adjust minir model attr changes
From: @xu_anyue Reviewed-by: @hangangqiang Signed-off-by:
This commit is contained in:
commit
c253c3a5c6
|
@ -51,7 +51,7 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
||||||
MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
|
MS_LOG(INFO) << "BiasAdd's attr axis is set to default";
|
||||||
attr->axis = {1};
|
attr->axis = {1};
|
||||||
} else {
|
} else {
|
||||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||||
}
|
}
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
if (this->primitive_->value.value == nullptr) {
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
|
|
@ -49,7 +49,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
|
||||||
MS_LOG(WARNING) << "get axis failed";
|
MS_LOG(WARNING) << "get axis failed";
|
||||||
attr->axis = {0};
|
attr->axis = {0};
|
||||||
} else {
|
} else {
|
||||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||||
}
|
}
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
if (this->primitive_->value.value == nullptr) {
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||||
attr->axis = prim_axis;
|
attr->axis = prim_axis;
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
if (this->primitive_->value.value == nullptr) {
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
|
|
@ -139,21 +139,21 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[2];
|
attr->strideH = stride[2];
|
||||||
attr->strideW = stride[3];
|
attr->strideW = stride[3];
|
||||||
|
|
||||||
|
@ -175,7 +175,7 @@ void Conv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT
|
||||||
|
|
||||||
int channel_mutiplier = 1;
|
int channel_mutiplier = 1;
|
||||||
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
||||||
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||||
}
|
}
|
||||||
attr->channelMultiplier = channel_mutiplier;
|
attr->channelMultiplier = channel_mutiplier;
|
||||||
|
|
||||||
|
@ -212,25 +212,25 @@ void Conv2D::PopulaterConv2DSingleGroup(const Primitive &prim, schema::Primitive
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[2];
|
attr->strideH = stride[2];
|
||||||
attr->strideW = stride[3];
|
attr->strideW = stride[3];
|
||||||
|
|
||||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||||
|
|
||||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||||
if (pad_mode == "valid") {
|
if (pad_mode == "valid") {
|
||||||
|
@ -270,7 +270,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
|
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
int group = GetValue<int>(groupAttr);
|
int group = CastToInt(groupAttr, false).front();
|
||||||
if (group > 1) {
|
if (group > 1) {
|
||||||
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -94,7 +94,7 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
|
||||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
attr->group = GetValue<int>(prim.GetAttr("group"));
|
attr->group = CastToInt(prim.GetAttr("group"), false).front();
|
||||||
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
|
||||||
if (format == "NCHW") {
|
if (format == "NCHW") {
|
||||||
attr->format = schema::Format_NCHW;
|
attr->format = schema::Format_NCHW;
|
||||||
|
@ -103,25 +103,25 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[0];
|
attr->strideH = stride[0];
|
||||||
attr->strideW = stride[1];
|
attr->strideW = stride[1];
|
||||||
|
|
||||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||||
if (pad_mode == "valid") {
|
if (pad_mode == "valid") {
|
||||||
attr->padMode = schema::PadMode_VALID;
|
attr->padMode = schema::PadMode_VALID;
|
||||||
|
|
|
@ -92,7 +92,7 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
||||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
attr->group = GetValue<int>(prim.GetAttr("group"));
|
attr->group = CastToInt(prim.GetAttr("group"), false).front();
|
||||||
if (attr->group > 1) {
|
if (attr->group > 1) {
|
||||||
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
|
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
|
||||||
}
|
}
|
||||||
|
@ -104,25 +104,25 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[0];
|
attr->strideH = stride[0];
|
||||||
attr->strideW = stride[1];
|
attr->strideW = stride[1];
|
||||||
|
|
||||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||||
|
|
||||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||||
if (pad_mode == "valid") {
|
if (pad_mode == "valid") {
|
||||||
|
|
|
@ -132,21 +132,21 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[0];
|
attr->strideH = stride[0];
|
||||||
attr->strideW = stride[1];
|
attr->strideW = stride[1];
|
||||||
|
|
||||||
|
@ -168,7 +168,7 @@ void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::Primitiv
|
||||||
|
|
||||||
int channel_mutiplier = 1;
|
int channel_mutiplier = 1;
|
||||||
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
if (prim.GetAttr("channel_mutiplier") != nullptr) {
|
||||||
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
channel_mutiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||||
}
|
}
|
||||||
attr->channelMultiplier = channel_mutiplier;
|
attr->channelMultiplier = channel_mutiplier;
|
||||||
|
|
||||||
|
@ -195,25 +195,25 @@ void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::Primi
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
|
auto pad_list = CastToInt(prim.GetAttr("pad_list"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[0];
|
attr->strideH = stride[0];
|
||||||
attr->strideW = stride[1];
|
attr->strideW = stride[1];
|
||||||
|
|
||||||
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
|
attr->channelOut = CastToInt(prim.GetAttr("out_channel"), false).front();
|
||||||
|
|
||||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
|
||||||
if (pad_mode == "valid" || pad_mode == "VALID") {
|
if (pad_mode == "valid" || pad_mode == "VALID") {
|
||||||
|
@ -248,7 +248,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
|
||||||
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
int group = GetValue<int>(prim.GetAttr("group"));
|
int group = CastToInt(prim.GetAttr("group"), false).front();
|
||||||
if (group == 1) {
|
if (group == 1) {
|
||||||
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
|
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
|
||||||
} else if (group > 1) {
|
} else if (group > 1) {
|
||||||
|
|
|
@ -86,27 +86,27 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
||||||
} else {
|
} else {
|
||||||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||||
}
|
}
|
||||||
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pads"));
|
auto pad_list = CastToInt(prim.GetAttr("pads"), true);
|
||||||
attr->padUp = pad_list[0];
|
attr->padUp = pad_list[0];
|
||||||
attr->padDown = pad_list[1];
|
attr->padDown = pad_list[1];
|
||||||
attr->padLeft = pad_list[2];
|
attr->padLeft = pad_list[2];
|
||||||
attr->padRight = pad_list[3];
|
attr->padRight = pad_list[3];
|
||||||
|
|
||||||
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
|
auto dilation = CastToInt(prim.GetAttr("dilation"), true);
|
||||||
attr->dilateH = dilation[0];
|
attr->dilateH = dilation[0];
|
||||||
attr->dilateW = dilation[1];
|
attr->dilateW = dilation[1];
|
||||||
|
|
||||||
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
|
if (utils::isa<ValueSequeue>(prim.GetAttr("kernel_size"))) {
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), true);
|
||||||
attr->kernelH = kernel_size[0];
|
attr->kernelH = kernel_size[0];
|
||||||
attr->kernelW = kernel_size[1];
|
attr->kernelW = kernel_size[1];
|
||||||
} else {
|
} else {
|
||||||
auto kernel_size = GetValue<int>(prim.GetAttr("kernel_size"));
|
auto kernel_size = CastToInt(prim.GetAttr("kernel_size"), false).front();
|
||||||
attr->kernelH = kernel_size;
|
attr->kernelH = kernel_size;
|
||||||
attr->kernelW = kernel_size;
|
attr->kernelW = kernel_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
|
auto stride = CastToInt(prim.GetAttr("stride"), true);
|
||||||
attr->strideH = stride[2];
|
attr->strideH = stride[2];
|
||||||
attr->strideW = stride[3];
|
attr->strideW = stride[3];
|
||||||
|
|
||||||
|
@ -124,7 +124,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
|
||||||
} else {
|
} else {
|
||||||
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
attr->activationType = schema::ActivationType_NO_ACTIVATION;
|
||||||
}
|
}
|
||||||
auto channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
|
auto channel_multiplier = CastToInt(prim.GetAttr("channel_multiplier"), false).front();
|
||||||
attr->channelMultiplier = channel_multiplier;
|
attr->channelMultiplier = channel_multiplier;
|
||||||
|
|
||||||
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
||||||
|
|
|
@ -53,7 +53,7 @@ int ExpandDims::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
||||||
// use axis instead of dim
|
// use axis instead of dim
|
||||||
if (inputs[1]->isa<ValueNode>()) {
|
if (inputs[1]->isa<ValueNode>()) {
|
||||||
auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
|
auto axis_tensor = inputs[1]->cast<ValueNodePtr>();
|
||||||
int axis = GetValue<int>(axis_tensor->value());
|
int axis = CastToInt(axis_tensor->value(), false).front();
|
||||||
attr->dim = axis;
|
attr->dim = axis;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "input axis is not value node.";
|
MS_LOG(ERROR) << "input axis is not value node.";
|
||||||
|
|
|
@ -59,7 +59,7 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
}
|
}
|
||||||
if (inputs[2]->isa<ValueNode>()) {
|
if (inputs[2]->isa<ValueNode>()) {
|
||||||
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
|
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
|
||||||
int axis = GetValue<int>(axis_tensor->value());
|
int axis = CastToInt(axis_tensor->value(), false).front();
|
||||||
gather_attr->axis = axis;
|
gather_attr->axis = axis;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "input axis is not value node.";
|
MS_LOG(ERROR) << "input axis is not value node.";
|
||||||
|
|
|
@ -48,7 +48,7 @@ int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
}
|
}
|
||||||
attr->axis = -1;
|
attr->axis = -1;
|
||||||
if (prim.GetAttr("axis") != nullptr) {
|
if (prim.GetAttr("axis") != nullptr) {
|
||||||
attr->axis = GetValue<int>(prim.GetAttr("axis"));
|
attr->axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||||
}
|
}
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
if (this->primitive_->value.value == nullptr) {
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
|
|
@ -110,11 +110,11 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
||||||
attr->padMode = schema::PadMode_NOTSET;
|
attr->padMode = schema::PadMode_NOTSET;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
|
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
|
||||||
attr->windowH = kernel_size[2];
|
attr->windowH = kernel_size[2];
|
||||||
attr->windowW = kernel_size[3];
|
attr->windowW = kernel_size[3];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
|
auto stride = CastToInt(prim.GetAttr("strides"), true);
|
||||||
attr->strideH = stride[2];
|
attr->strideH = stride[2];
|
||||||
attr->strideW = stride[3];
|
attr->strideW = stride[3];
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
|
|
|
@ -99,11 +99,11 @@ int PoolingGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
|
||||||
attr->padMode = schema::PadMode_NOTSET;
|
attr->padMode = schema::PadMode_NOTSET;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("ksize"));
|
auto kernel_size = CastToInt(prim.GetAttr("ksize"), true);
|
||||||
attr->windowH = kernel_size[2];
|
attr->windowH = kernel_size[2];
|
||||||
attr->windowW = kernel_size[3];
|
attr->windowW = kernel_size[3];
|
||||||
|
|
||||||
auto stride = GetValue<std::vector<int>>(prim.GetAttr("strides"));
|
auto stride = CastToInt(prim.GetAttr("strides"), true);
|
||||||
attr->strideH = stride[2];
|
attr->strideH = stride[2];
|
||||||
attr->strideW = stride[3];
|
attr->strideW = stride[3];
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
|
|
|
@ -180,6 +180,35 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
#ifdef PRIMITIVE_WRITEABLE
|
#ifdef PRIMITIVE_WRITEABLE
|
||||||
|
std::vector<int> CastToInt(const ValuePtr value, bool is_vector) {
|
||||||
|
if (value == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "valueptr is nullptr.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
std::vector<int> cur_value;
|
||||||
|
if (is_vector) {
|
||||||
|
if (!utils::isa<ValueSequeuePtr>(value)) {
|
||||||
|
MS_LOG(WARNING) << "valueptr is not a sequence, value may be a scalar.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
if (value->cast<ValueSequeuePtr>()->value().front()->type()->type_name() == "Int64Imm") {
|
||||||
|
auto origin_value = GetValue<std::vector<int64_t>>(value);
|
||||||
|
for (size_t index = 0; index < origin_value.size(); ++index) {
|
||||||
|
cur_value.push_back(static_cast<int>(origin_value[index]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cur_value = GetValue<std::vector<int>>(value);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (value->type_name() == "Int64Imm") {
|
||||||
|
cur_value.push_back(static_cast<int>(GetValue<int64_t>(value)));
|
||||||
|
} else {
|
||||||
|
cur_value.push_back(GetValue<int>(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cur_value;
|
||||||
|
}
|
||||||
|
|
||||||
void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
void PrimitiveC::CalFloatScopeByMeanAndStddev(const double &mean, const double &stdDev, float *mMin, float *mMax) {
|
||||||
const float qmin = 0;
|
const float qmin = 0;
|
||||||
const float qmax = 255;
|
const float qmax = 255;
|
||||||
|
|
|
@ -52,6 +52,8 @@ static std::map<std::string, schema::ActivationType> kActivationTypeMap{{"ReLU",
|
||||||
{"Sigmoid", schema::ActivationType_SIGMOID},
|
{"Sigmoid", schema::ActivationType_SIGMOID},
|
||||||
{"HSwish", schema::ActivationType_HSWISH},
|
{"HSwish", schema::ActivationType_HSWISH},
|
||||||
{"HSigmoid", schema::ActivationType_HSIGMOID}};
|
{"HSigmoid", schema::ActivationType_HSIGMOID}};
|
||||||
|
std::vector<int> CastToInt(const ValuePtr value, bool is_vector);
|
||||||
|
|
||||||
class PrimitiveC : public mindspore::Primitive {
|
class PrimitiveC : public mindspore::Primitive {
|
||||||
public:
|
public:
|
||||||
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
|
// Argument primitive is deliverd into PrimitiveC and will be deleted in ~PrimitiveC().
|
||||||
|
|
|
@ -87,7 +87,7 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
attr->axes.emplace_back(elem->value());
|
attr->axes.emplace_back(elem->value());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int axes_item = GetValue<int>(value);
|
int axes_item = CastToInt(value, false).front();
|
||||||
attr->axes.push_back(axes_item);
|
attr->axes.push_back(axes_item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
||||||
attr->shape.emplace_back(static_cast<int>(elem->value()));
|
attr->shape.emplace_back(static_cast<int>(elem->value()));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int dim = GetValue<int>(val);
|
int dim = CastToInt(val, false).front();
|
||||||
attr->shape = {dim};
|
attr->shape = {dim};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ int Resize::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
|
||||||
MS_LOG(ERROR) << "wrong resize type";
|
MS_LOG(ERROR) << "wrong resize type";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
std::vector<int> targetSize = GetValue<std::vector<int>>(prim.GetAttr("size"));
|
std::vector<int> targetSize = CastToInt(prim.GetAttr("size"), true);
|
||||||
attr->newHeight = targetSize[0];
|
attr->newHeight = targetSize[0];
|
||||||
attr->newWidth = targetSize[1];
|
attr->newWidth = targetSize[1];
|
||||||
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));
|
attr->alignCorners = GetValue<bool>(prim.GetAttr("align_corners"));
|
||||||
|
|
|
@ -43,7 +43,7 @@ int SoftMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
||||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
|
auto prim_axis = CastToInt(prim.GetAttr("axis"), false).front();
|
||||||
attr->axis = prim_axis;
|
attr->axis = prim_axis;
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
if (this->primitive_->value.value == nullptr) {
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
|
|
@ -50,7 +50,7 @@ int Squeeze::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
||||||
MS_LOG(INFO) << "Squeeze's attr xis is set to default";
|
MS_LOG(INFO) << "Squeeze's attr xis is set to default";
|
||||||
attr->axis = {0};
|
attr->axis = {0};
|
||||||
} else {
|
} else {
|
||||||
attr->axis = GetValue<std::vector<int>>(prim.GetAttr("axis"));
|
attr->axis = CastToInt(prim.GetAttr("axis"), true);
|
||||||
}
|
}
|
||||||
this->primitive_->value.value = attr;
|
this->primitive_->value.value = attr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,11 +73,11 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr
|
||||||
MS_LOG(ERROR) << "new StridedSlice failed";
|
MS_LOG(ERROR) << "new StridedSlice failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask"));
|
attr->beginMask = CastToInt(prim.GetAttr("begin_mask"), false).front();
|
||||||
attr->endMask = GetValue<int>(prim.GetAttr("end_mask"));
|
attr->endMask = CastToInt(prim.GetAttr("end_mask"), false).front();
|
||||||
attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask"));
|
attr->ellipsisMask = CastToInt(prim.GetAttr("ellipsis_mask"), false).front();
|
||||||
attr->newAxisMask = GetValue<int>(prim.GetAttr("new_axis_mask"));
|
attr->newAxisMask = CastToInt(prim.GetAttr("new_axis_mask"), false).front();
|
||||||
attr->shrinkAxisMask = GetValue<int>(prim.GetAttr("shrink_axis_mask"));
|
attr->shrinkAxisMask = CastToInt(prim.GetAttr("shrink_axis_mask"), false).front();
|
||||||
auto inputNodeFirst = inputs[kAnfPopulaterOne];
|
auto inputNodeFirst = inputs[kAnfPopulaterOne];
|
||||||
std::vector<int> beginVec;
|
std::vector<int> beginVec;
|
||||||
GetAttrDataFromInput(inputNodeFirst, &beginVec);
|
GetAttrDataFromInput(inputNodeFirst, &beginVec);
|
||||||
|
|
|
@ -56,7 +56,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
|
||||||
MS_LOG(INFO) << "Tile's attr dims is set to default";
|
MS_LOG(INFO) << "Tile's attr dims is set to default";
|
||||||
attr->dims = {1};
|
attr->dims = {1};
|
||||||
} else {
|
} else {
|
||||||
attr->dims = GetValue<std::vector<int>>(prim.GetAttr("dims"));
|
attr->dims = CastToInt(prim.GetAttr("dims"), true);
|
||||||
}
|
}
|
||||||
if (inputs.size() == kAnfPopulaterTwo) {
|
if (inputs.size() == kAnfPopulaterTwo) {
|
||||||
auto inputNode = inputs[kAnfPopulaterOne];
|
auto inputNode = inputs[kAnfPopulaterOne];
|
||||||
|
@ -75,7 +75,7 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
|
||||||
attr->multiples.emplace_back(elem->value());
|
attr->multiples.emplace_back(elem->value());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
int multiple = GetValue<int>(value);
|
int multiple = CastToInt(value, false).front();
|
||||||
attr->multiples = {multiple};
|
attr->multiples = {multiple};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ int UnsortedSegmentSum::UnPackAttr(const Primitive &prim, const std::vector<AnfN
|
||||||
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
|
std::unique_ptr<schema::UnsortedSegmentSumT> attr = std::make_unique<schema::UnsortedSegmentSumT>();
|
||||||
if (inputs[2]->isa<ValueNode>()) {
|
if (inputs[2]->isa<ValueNode>()) {
|
||||||
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
|
ValuePtr value = inputs[2]->cast<ValueNodePtr>()->value();
|
||||||
attr->numSegments = GetValue<int>(value);
|
attr->numSegments = CastToInt(value, false).front();
|
||||||
this->primitive_->value.value = attr.release();
|
this->primitive_->value.value = attr.release();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -314,7 +314,9 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, s
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto input_index_key =
|
auto input_index_key =
|
||||||
get_item_input_cnode->fullname_with_scope() + "_o:" + std::to_string(GetValue<int>(value_node->value()));
|
get_item_input_cnode->fullname_with_scope() + "_o:" +
|
||||||
|
std::to_string(value_node->value()->type_name() == "Int64Imm" ? GetValue<int64_t>(value_node->value())
|
||||||
|
: GetValue<int>(value_node->value()));
|
||||||
auto iter = node_id_map_.find(input_index_key);
|
auto iter = node_id_map_.find(input_index_key);
|
||||||
if (iter == node_id_map_.end()) {
|
if (iter == node_id_map_.end()) {
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
|
|
Loading…
Reference in New Issue