forked from OSSInnovation/mindspore
Add converter method for operator 'Split'.
This commit is contained in:
parent
9da592a99f
commit
5953855106
|
@ -418,6 +418,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
||||||
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
|
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
|
||||||
} else if (op_type == "Cast") {
|
} else if (op_type == "Cast") {
|
||||||
return NewPrimitiveC<Cast>(prim, inputs, quantType);
|
return NewPrimitiveC<Cast>(prim, inputs, quantType);
|
||||||
|
} else if (op_type == "Split") {
|
||||||
|
return NewPrimitiveC<Split>(prim, inputs, quantType);
|
||||||
|
|
||||||
#ifdef SUPPORT_TRAIN
|
#ifdef SUPPORT_TRAIN
|
||||||
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
|
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
|
||||||
|
|
|
@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector<int> &size_splits) {
|
||||||
}
|
}
|
||||||
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; }
|
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; }
|
||||||
|
|
||||||
|
int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||||
|
if (this->primitive_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
this->primitive_->value.type = schema::PrimitiveType_Split;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.type != schema::PrimitiveType_Split) {
|
||||||
|
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
auto attr = new (std::nothrow) schema::SplitT();
|
||||||
|
if (attr == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis"));
|
||||||
|
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num"));
|
||||||
|
this->primitive_->value.value = attr;
|
||||||
|
if (this->primitive_->value.value == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
|
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
|
||||||
|
@ -99,12 +130,14 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
||||||
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
|
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
|
||||||
int split_dim_i = input_shape[split_dim];
|
int split_dim_i = input_shape[split_dim];
|
||||||
// support split size is -1 in the end.
|
// support split size is -1 in the end.
|
||||||
if (i == number_split - 1 && size_split[i] == -1) {
|
if (size_split.empty()) {
|
||||||
|
split_dim_i = input_shape[split_dim] / number_split;
|
||||||
|
} else if (i == number_split - 1 && size_split[i] == -1) {
|
||||||
for (size_t j = 0; j < size_split.size() - 1; ++j) {
|
for (size_t j = 0; j < size_split.size() - 1; ++j) {
|
||||||
split_dim_i -= size_split[j];
|
split_dim_i -= size_split[j];
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
|
split_dim_i = size_split[i];
|
||||||
}
|
}
|
||||||
output_shape[split_dim] = split_dim_i;
|
output_shape[split_dim] = split_dim_i;
|
||||||
outputs_[i]->set_shape(output_shape);
|
outputs_[i]->set_shape(output_shape);
|
||||||
|
|
|
@ -35,6 +35,7 @@ class Split : public PrimitiveC {
|
||||||
void SetNumberSplit(int number_split);
|
void SetNumberSplit(int number_split);
|
||||||
void SetSizeSplits(const std::vector<int> &size_splits);
|
void SetSizeSplits(const std::vector<int> &size_splits);
|
||||||
void SetSplitDim(int split_dim);
|
void SetSplitDim(int split_dim);
|
||||||
|
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||||
#else
|
#else
|
||||||
Split() = default;
|
Split() = default;
|
||||||
|
|
||||||
|
|
|
@ -419,6 +419,9 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
||||||
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
|
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
} else if (value->isa<Number>()) {
|
||||||
|
MS_LOG(INFO) << "Value is a number.";
|
||||||
|
return RET_OK;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Not support value type , need add support.";
|
MS_LOG(ERROR) << "Not support value type , need add support.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
Loading…
Reference in New Issue