forked from mindspore-Ecosystem/mindspore
!8639 add dims parameter for tile ops
From: @liuwenhao4 Reviewed-by: @HilbertDavid,@zhanghaibo5,@zhang_xue_tong Signed-off-by:
This commit is contained in:
commit
b4773004fb
|
@ -24,6 +24,7 @@ typedef struct TileParameter {
|
|||
int in_dim_;
|
||||
int in_shape_[5];
|
||||
int out_shape_[5];
|
||||
int dims_[5];
|
||||
int multiples_[5];
|
||||
int in_strides_[5];
|
||||
int out_strides_[5];
|
||||
|
|
|
@ -31,10 +31,13 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
memset(tile_param, 0, sizeof(TileParameter));
|
||||
tile_param->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::Tile *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto dims = param->GetDims();
|
||||
auto multiples = param->GetMultiples();
|
||||
tile_param->in_dim_ = multiples.size();
|
||||
for (int i = 0; i < tile_param->in_dim_; ++i) {
|
||||
tile_param->multiples_[i] = multiples[i];
|
||||
for (size_t i = 0; i < kDimension_4d; ++i) {
|
||||
tile_param->multiples_[i] = 1;
|
||||
}
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
tile_param->multiples_[dims[i]] = multiples[i];
|
||||
}
|
||||
return reinterpret_cast<OpParameter *>(tile_param);
|
||||
}
|
||||
|
|
|
@ -140,18 +140,17 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
|
|||
|
||||
std::vector<int> out_shape;
|
||||
std::vector<int> multiples = GetMultiples();
|
||||
std::vector<int> dims = GetDims();
|
||||
const size_t in_dims = input->shape().size();
|
||||
const size_t delta_dims = in_dims - multiples.size();
|
||||
|
||||
size_t i = 0;
|
||||
for (; i < delta_dims; ++i) {
|
||||
int tmp = input->shape()[i];
|
||||
out_shape.push_back(tmp);
|
||||
MS_ASSERT(multiples.size() == dims.size());
|
||||
for (size_t i = 0; i < in_dims; ++i) {
|
||||
out_shape.push_back(input->shape()[i]);
|
||||
}
|
||||
for (; i < in_dims; ++i) {
|
||||
int tmp = input->shape()[i] * (multiples[i - delta_dims]);
|
||||
out_shape.push_back(tmp);
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
out_shape[dims[i]] = input->shape()[dims[i]] * (multiples[i]);
|
||||
}
|
||||
|
||||
output->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
|
|||
|
||||
int TileCPUKernel::ReSize() {
|
||||
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
|
||||
tile_parameter_->in_dim_ = in_tensors_[0]->shape().size();
|
||||
for (int i = 0; i < tile_parameter_->in_dim_; ++i) {
|
||||
tile_parameter_->in_shape_[i] = in_tensors_[0]->shape()[i];
|
||||
tile_parameter_->out_shape_[i] = out_tensors_[0]->shape()[i];
|
||||
|
|
Loading…
Reference in New Issue