forked from OSSInnovation/mindspore
Add converter method for operator 'Split'.
This commit is contained in:
parent
9da592a99f
commit
5953855106
|
@ -418,6 +418,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
|
|||
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
|
||||
} else if (op_type == "Cast") {
|
||||
return NewPrimitiveC<Cast>(prim, inputs, quantType);
|
||||
} else if (op_type == "Split") {
|
||||
return NewPrimitiveC<Split>(prim, inputs, quantType);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
|
||||
|
|
|
@ -29,6 +29,37 @@ void Split::SetSizeSplits(const std::vector<int> &size_splits) {
|
|||
}
|
||||
void Split::SetSplitDim(int split_dim) { this->primitive_->value.AsSplit()->splitDim = split_dim; }
|
||||
|
||||
int Split::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_Split;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_Split) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::SplitT();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
attr->splitDim = GetValue<int32_t>(prim.GetAttr("axis"));
|
||||
attr->numberSplit = GetValue<int32_t>(prim.GetAttr("output_num"));
|
||||
this->primitive_->value.value = attr;
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive value is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
int Split::GetNumberSplit() const { return this->primitive_->value_as_Split()->numberSplit(); }
|
||||
|
@ -99,12 +130,14 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
|
|||
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
|
||||
int split_dim_i = input_shape[split_dim];
|
||||
// support split size is -1 in the end.
|
||||
if (i == number_split - 1 && size_split[i] == -1) {
|
||||
if (size_split.empty()) {
|
||||
split_dim_i = input_shape[split_dim] / number_split;
|
||||
} else if (i == number_split - 1 && size_split[i] == -1) {
|
||||
for (size_t j = 0; j < size_split.size() - 1; ++j) {
|
||||
split_dim_i -= size_split[j];
|
||||
}
|
||||
} else {
|
||||
split_dim_i = size_split.empty() ? input_shape[split_dim] / number_split : size_split[i];
|
||||
split_dim_i = size_split[i];
|
||||
}
|
||||
output_shape[split_dim] = split_dim_i;
|
||||
outputs_[i]->set_shape(output_shape);
|
||||
|
|
|
@ -35,6 +35,7 @@ class Split : public PrimitiveC {
|
|||
void SetNumberSplit(int number_split);
|
||||
void SetSizeSplits(const std::vector<int> &size_splits);
|
||||
void SetSplitDim(int split_dim);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
|
||||
#else
|
||||
Split() = default;
|
||||
|
||||
|
|
|
@ -419,10 +419,13 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode,
|
|||
MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << ".";
|
||||
}
|
||||
#endif
|
||||
} else if (value->isa<Number>()) {
|
||||
MS_LOG(INFO) << "Value is a number.";
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support value type , need add support.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(ERROR) << "Not support value type , need add support.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue