!6092 No broadcast when simplifying constants multiplication

Merge pull request !6092 from thlinh/dev_Sep10_no_broadcast_constant_mul
mindspore-ci-bot authored 2020-09-12 14:46:48 +08:00, committed by Gitee
commit 7a52e30e45
1 changed file with 40 additions and 25 deletions
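
For readers skimming the diff below, the rule this change introduces can be summarized as: if the two constant operands already share a shape and element type, fold them into a constant of that same shape (no broadcast); only when they differ does the pass fall back to the shape of the original multiplication node, and then only if the types match and each operand is either a scalar or already has the output's element count. The sketch below is illustrative only; the helper name PickFoldedShape, the flat parameter list, and the int type ids are assumptions and not part of the patch.

#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical illustration of the shape choice made by the patched pass.
// Returns the shape for the folded constant, or std::nullopt to skip the fold.
std::optional<Shape> PickFoldedShape(const Shape &shape_1, int type_1, const Shape &shape_2, int type_2,
                                     const Shape &out_shape, int out_type) {
  auto elements = [](const Shape &s) {
    return std::accumulate(s.begin(), s.end(), int64_t(1), std::multiplies<int64_t>());
  };
  if (shape_1 == shape_2 && type_1 == type_2) {
    return shape_1;  // Same shape and type: keep it, no broadcast to the output shape.
  }
  if (type_1 != out_type || type_2 != out_type) {
    return std::nullopt;  // Mixed types: leave the graph untouched.
  }
  const int64_t out_size = elements(out_shape);
  const int64_t size_1 = elements(shape_1);
  const int64_t size_2 = elements(shape_2);
  // Each operand must be a scalar or already match the output element count.
  if ((size_1 > 1 && size_1 != out_size) || (size_2 > 1 && size_2 != out_size)) {
    return std::nullopt;
  }
  return out_shape;  // Fall back to the original node's shape.
}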

@@ -615,7 +615,7 @@ class PConstant : public PBase<PConstant<T> > {
       return new_vnode;
     }
     // x is not nullptr
-    if (x->isa<CNode>()) {
+    if (x->isa<CNode>() || x->isa<Parameter>()) {
       if ((x->abstract() == nullptr) || !x->abstract()->isa<abstract::AbstractTensor>()) {
         return nullptr;
       }
@@ -650,8 +650,9 @@ class PConstant : public PBase<PConstant<T> > {
       ret = memcpy_s(data, mem_size, source_data, mem_size);
     }
     if (ret != 0) {
-      MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size"
-                        << new_tensor_ptr->DataSize();
+      MS_LOG(INFO) << "memcpy_s error, error no " << ret << ", source size " << mem_size << "dest size"
+                   << new_tensor_ptr->DataSize();
+      return nullptr;
     }
     auto new_vnode = NewValueNode(new_tensor_ptr);
     new_vnode->set_abstract(new_tensor_ptr->ToAbstract());
@@ -735,46 +736,60 @@ class PConstant : public PBase<PConstant<T> > {
     auto tensor_1_abstract = vnode_1->abstract()->cast<abstract::AbstractTensorPtr>();
     auto tensor_2_abstract = vnode_2->abstract()->cast<abstract::AbstractTensorPtr>();
-    auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
     TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType();
     TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType();
-    TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
-    if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
-        (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
-      return nullptr;
+    ShapeVector tensor_out_shape;
+    int data_out_size;
+    tensor::TensorPtr new_tensor_ptr;
+    if ((tensor_1_abstract->shape()->shape() == tensor_2_abstract->shape()->shape()) &&
+        (tensor_1_type_ptr->type_id() == tensor_2_type_ptr->type_id())) {
+      // If two constant nodes have the same shape, then create a new one with this shape
+      tensor_out_shape = tensor_1_abstract->shape()->shape();
+      data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
+      new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_1_type_ptr->type_id(), tensor_out_shape);
+    } else {
+      // If two constant nodes have different shapes, then create a new one node with the shape of the 3rd node
+      auto tensor_3_abstract = node_3->abstract()->cast<abstract::AbstractTensorPtr>();
+      TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType();
+      if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) ||
+          (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) {
+        return nullptr;
+      }
+      tensor_out_shape = tensor_3_abstract->shape()->shape();
+      data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
+      if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
+        return nullptr;
+      }
+      if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
+        return nullptr;
+      }
+      new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
     }
-    ShapeVector tensor_out_shape = tensor_3_abstract->shape()->shape();
-    int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
-    if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
-      return nullptr;
-    }
-    if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) {
-      return nullptr;
-    }
-    auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_3_type_ptr->type_id(), tensor_out_shape);
-    size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
+    size_t mem_size = GetTypeByte(new_tensor_ptr->Dtype()) * IntToSize(new_tensor_ptr->ElementsNum());
     char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
     int ret = 0;
     void *data_out = nullptr;
-    if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) ||
-        (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) {
+    if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat32) ||
+        (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat)) {
       Multiply<float>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                       tensor_ptr_2->DataSize(), &data_out, data_out_size);
       ret = memcpy_s(data, mem_size, data_out, mem_size);
       delete[] reinterpret_cast<float *>(data_out);
     } else {
-      if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) {
+      if (new_tensor_ptr->data_type() == TypeId::kNumberTypeFloat64) {
        Multiply<double>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                         tensor_ptr_2->DataSize(), &data_out, data_out_size);
        ret = memcpy_s(data, mem_size, data_out, mem_size);
        delete[] reinterpret_cast<double *>(data_out);
      } else {
-        if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) ||
-            (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) {
+        if ((new_tensor_ptr->data_type() == TypeId::kNumberTypeInt32) ||
+            (new_tensor_ptr->data_type() == TypeId::kNumberTypeInt)) {
          Multiply<int>(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(),
                        tensor_ptr_2->DataSize(), &data_out, data_out_size);
          ret = memcpy_s(data, mem_size, data_out, mem_size);
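
The Multiply<T> helper called in this hunk is defined elsewhere in pattern_matcher.h and is not shown in the diff. Below is a minimal sketch of the behaviour the call sites appear to assume: a size-1 operand is treated as a scalar and broadcast across the output, otherwise it is read elementwise. The name MultiplySketch and the exact signature are assumptions, not the project's code.

// Illustrative only: elementwise product with scalar splatting, mirroring how the
// call sites above pass DataSize() for each operand and data_out_size for the result.
template <typename TM>
void MultiplySketch(const void *in_1, int size_1, const void *in_2, int size_2, void **out, int out_size) {
  const TM *data_1 = static_cast<const TM *>(in_1);
  const TM *data_2 = static_cast<const TM *>(in_2);
  TM *data_out = new TM[out_size];
  for (int i = 0; i < out_size; ++i) {
    const TM a = (size_1 == 1) ? data_1[0] : data_1[i];  // scalar operand is broadcast
    const TM b = (size_2 == 1) ? data_2[0] : data_2[i];
    data_out[i] = a * b;
  }
  *out = data_out;  // Caller memcpy_s's this buffer into the new tensor, then delete[]s it.
}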