!7943 fix sync code bug to support fm

Merge pull request !7943 from yangjie159/sync_code
mindspore-ci-bot authored on 2020-10-29 14:40:59 +08:00, committed by Gitee
commit db0315bff5
21 changed files with 581 additions and 5 deletions

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include "nnacl/fp32_grad/binary_cross_entropy.h"
static void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const float *input_x,
                                         const float *input_y, const float *weight, float *loss, float *tmp_loss) {
  float epsilon = 1e-12f;
  if (reduction == 0) {
    // reduction 'none': write the weighted per-element loss straight into the output buffer.
    for (int i = 0; i < input_size; i++) {
      float value =
        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
      loss[i] = value;
    }
  } else {
    // 'mean'/'sum': stage the per-element loss in tmp_loss; BinaryCrossEntropy reduces it below.
    for (int i = 0; i < input_size; i++) {
      float value =
        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
      tmp_loss[i] = value;
    }
  }
}
void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
                        const float *weight, float *loss, float *tmp_loss) {
  loss[0] = 0.0f;
  BinaryCrossEntropyLossKernel(input_size, reduction, input_x, input_y, weight, loss, tmp_loss);
  if (reduction != 0) {
    // Pairwise tree reduction over tmp_loss: fold in the odd tail element, then halve the stride each pass.
    if (input_size % 2 == 1) {
      tmp_loss[0] += tmp_loss[input_size - 1];
    }
    for (int stride = input_size / 2; stride > 0; stride >>= 1) {
      for (int i = 0; i < stride; i++) {
        tmp_loss[i] += tmp_loss[i + stride];
      }
      if (stride > 2 && stride % 2 == 1) {
        tmp_loss[0] += tmp_loss[stride - 1];
      }
    }
    loss[0] += tmp_loss[0];
    // reduction == 1 is 'mean'; any other non-zero reduction keeps the accumulated sum.
    if (reduction == 1) {
      loss[0] /= input_size;
    }
  }
}
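A minimal usage sketch of the kernel above (not part of this change set). It assumes only the prototype declared in nnacl/fp32_grad/binary_cross_entropy.h and the reduction encoding implied by the code (0 = none, 1 = mean, any other value behaves as sum); the input values and buffer sizes are made up.
#include <cstdio>
#include <vector>
// Prototype as declared in nnacl/fp32_grad/binary_cross_entropy.h.
extern "C" void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x,
                                   const float *input_y, const float *weight, float *loss, float *tmp_loss);
int main() {
  const int n = 4;
  std::vector<float> x = {0.9f, 0.2f, 0.7f, 0.4f};  // predicted probabilities
  std::vector<float> y = {1.0f, 0.0f, 1.0f, 0.0f};  // binary labels
  std::vector<float> w(n, 1.0f);                    // per-element weights
  std::vector<float> loss(n, 0.0f);                 // n values for 'none', otherwise only loss[0] is used
  std::vector<float> tmp(n, 0.0f);                  // scratch buffer consumed by the tree reduction
  BinaryCrossEntropy(n, 1 /* mean */, x.data(), y.data(), w.data(), loss.data(), tmp.data());
  printf("mean BCE loss = %f\n", loss[0]);
  return 0;
}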

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_
#define MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_
#include "nnacl/op_base.h"
typedef struct BinaryCrossEntropyParameter {
OpParameter op_parameter_;
int reduction;
} BinaryCrossEntropyParameter;
#ifdef __cplusplus
extern "C" {
#endif
void BinaryCrossEntropy(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, float *loss, float *tmp_loss);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_H_

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
                           const float *weight, const float *dloss, float *dx) {
  float epsilon = 1e-12f;
  if (reduction == 0) {
    // reduction 'none': each element has its own upstream gradient dloss[i].
    for (int i = 0; i < input_size; i++) {
      float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
      float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
      dx[i] = value * dloss[i];
    }
  } else {
    // 'mean' (1) or 'sum': the scalar upstream gradient dloss[0] is broadcast to every element,
    // scaled by 1 / input_size in the 'mean' case.
    float dloss1 = dloss[0];
    if (reduction == 1) {
      dloss1 = dloss[0] / input_size;
    }
    for (int i = 0; i < input_size; i++) {
      float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);
      float value = weight[i] * (input_x[i] - input_y[i]) / denominator;
      dx[i] = value * dloss1;
    }
  }
  return 0;
}
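A companion sketch for the gradient kernel (again illustrative, not part of this change set). With reduction = 1 the single upstream value dloss[0] is divided by input_size and broadcast, so passing dloss[0] = 1.0f yields the gradient of the mean BCE loss with respect to each input.
#include <cstdio>
#include <vector>
// Prototype as declared in nnacl/fp32_grad/binary_cross_entropy_grad.h.
extern "C" int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x,
                                      const float *input_y, const float *weight, const float *dloss, float *dx);
int main() {
  const int n = 4;
  std::vector<float> x = {0.9f, 0.2f, 0.7f, 0.4f};  // predicted probabilities
  std::vector<float> y = {1.0f, 0.0f, 1.0f, 0.0f};  // binary labels
  std::vector<float> w(n, 1.0f);                    // per-element weights
  std::vector<float> dx(n, 0.0f);                   // output gradient buffer
  float dloss = 1.0f;                               // upstream gradient of the scalar mean loss
  BinaryCrossEntropyGrad(n, 1 /* mean */, x.data(), y.data(), w.data(), &dloss, dx.data());
  for (int i = 0; i < n; i++) {
    printf("dx[%d] = %f\n", i, dx[i]);
  }
  return 0;
}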

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_
#define MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_
#include "nnacl/op_base.h"
typedef struct BinaryCrossEntropyGradParameter {
OpParameter op_parameter_;
int reduction;
} BinaryCrossEntropyGradParameter;
#ifdef __cplusplus
extern "C" {
#endif
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, const float *dloss, float *dx);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_BINARY_CROSS_ENTROPY_GRAD_H_

View File

@@ -15,6 +15,11 @@
*/
#include "src/ops/assign_add.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -58,7 +63,13 @@ int AssignAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *AssignAddCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<AssignAdd>(primitive);
}
Registry AssignAddRegistry(schema::PrimitiveType_AssignAdd, AssignAddCreator);
#endif
int AssignAdd::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];
Tensor *y = inputs_[1];

View File

@@ -17,6 +17,10 @@
#include <string>
#include "src/ops/binary_cross_entropy.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -85,6 +89,11 @@ int BinaryCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive,
}
int BinaryCrossEntropy::GetReduction() const { return this->primitive_->value_as_BinaryCrossEntropy()->reduction(); }
PrimitiveC *BinaryCrossEntropyCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BinaryCrossEntropy>(primitive);
}
Registry BinaryCrossEntropyRegistry(schema::PrimitiveType_BinaryCrossEntropy, BinaryCrossEntropyCreator);
#endif
int BinaryCrossEntropy::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];

View File

@@ -17,6 +17,10 @@
#include <string>
#include "src/ops/binary_cross_entropy_grad.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -92,6 +96,11 @@ int BinaryCrossEntropyGrad::UnPackToFlatBuilder(const schema::Primitive *primiti
int BinaryCrossEntropyGrad::GetReduction() const {
return this->primitive_->value_as_BinaryCrossEntropyGrad()->reduction();
}
PrimitiveC *BinaryCrossEntropyGradCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BinaryCrossEntropyGrad>(primitive);
}
Registry BinaryCrossEntropyGradRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad, BinaryCrossEntropyGradCreator);
#endif
int BinaryCrossEntropyGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];

View File

@@ -89,7 +89,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
if (inputs_.size() != kSingleNum || outputs_.size() != kSingleNum) {
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "tensor number is error.";
return RET_INPUT_TENSOR_ERROR;
}

View File

@@ -54,8 +54,15 @@ int Gather::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
delete gather_attr;
return RET_ERROR;
}
gather_attr->axis = GetValue<int>(prim.GetAttr("axis"));
gather_attr->batchDims = GetValue<int>(prim.GetAttr("batchDims"));
if (inputs[2]->isa<ValueNode>()) {
ValueNodePtr axis_tensor = inputs[2]->cast<ValueNodePtr>();
int axis = GetValue<int>(axis_tensor->value());
gather_attr->axis = axis;
} else {
MS_LOG(ERROR) << "input axis is not value node.";
return RET_ERROR;
}
gather_attr->batchDims = 0;
this->primitive_->value.value = gather_attr;
}
return RET_OK;
@@ -85,8 +92,7 @@ Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator);
int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "Gather should have two inputs";
return RET_INPUT_TENSOR_ERROR;
MS_LOG(DEBUG) << "Gather should have two inputs";
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "Gather should have one outputs";

View File

@@ -16,6 +16,10 @@
#include "src/ops/oneslike.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -59,6 +63,11 @@ int OnesLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *OnesLikeCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<OnesLike>(primitive);
}
Registry OnesLikeRegistry(schema::PrimitiveType_OnesLike, OnesLikeCreator);
#endif
int OnesLike::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
Tensor *x = inputs_[0];

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/activation_grad.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/activation_grad.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateActivationGradParameter(const mindspore::lite::PrimitiveC *primitive) {
ActivationGradParameter *act_param =
reinterpret_cast<ActivationGradParameter *>(malloc(sizeof(ActivationGradParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationGradParameter));
act_param->op_parameter.type_ = primitive->Type();
auto activation =
reinterpret_cast<mindspore::lite::ActivationGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->type_ = static_cast<int>(activation->GetType());
act_param->alpha_ = activation->GetAlpha();
return reinterpret_cast<OpParameter *>(act_param);
}
Registry ActivationGradParameterRegistry(schema::PrimitiveType_ActivationGrad, PopulateActivationGradParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/adam.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateAdamParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc Adam Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}
Registry AdamParameterRegistry(schema::PrimitiveType_Adam, PopulateAdamParameter);
} // namespace lite
} // namespace mindspore
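The attribute-free populate functions in this change (Adam here, and Assign, AssignAdd, OnesLike and UnsortedSegmentSum below) all share one shape: allocate a bare OpParameter, zero it, stamp the primitive type, and register the creator by schema type. A minimal sketch of that shape for a hypothetical FooBar op — the name and schema::PrimitiveType_FooBar are placeholders, everything else mirrors the code above:
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
// FooBar is a hypothetical attribute-free op used only to illustrate the pattern.
OpParameter *PopulateFooBarParameter(const mindspore::lite::PrimitiveC *primitive) {
  OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc FooBar Parameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(OpParameter));
  param->type_ = primitive->Type();  // the only state such ops carry at runtime
  return param;
}
Registry FooBarParameterRegistry(schema::PrimitiveType_FooBar, PopulateFooBarParameter);
}  // namespace lite
}  // namespace mindspore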

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/assign_add.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateAssignAddParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc AssignAdd Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}
Registry AssignAddParameterRegistry(schema::PrimitiveType_AssignAdd, PopulateAssignAddParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/assign.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateAssignParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc Assign Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}
Registry AssignParameterRegistry(schema::PrimitiveType_Assign, PopulateAssignParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateBiasGradParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry PopulateBiasGradParameterParameterRegistry(schema::PrimitiveType_BiasGrad, PopulateBiasGradParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/binary_cross_entropy_grad.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy_grad.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyGradParameter(const mindspore::lite::PrimitiveC *primitive) {
BinaryCrossEntropyGradParameter *bce_param =
reinterpret_cast<BinaryCrossEntropyGradParameter *>(malloc(sizeof(BinaryCrossEntropyGradParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropyGrad Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyGradParameter));
bce_param->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::BinaryCrossEntropyGrad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
bce_param->reduction = param->GetReduction();
return reinterpret_cast<OpParameter *>(bce_param);
}
Registry BinaryCrossEntropyGradParameterRegistry(schema::PrimitiveType_BinaryCrossEntropyGrad,
PopulateBinaryCrossEntropyGradParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/binary_cross_entropy.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32_grad/binary_cross_entropy.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateBinaryCrossEntropyParameter(const mindspore::lite::PrimitiveC *primitive) {
BinaryCrossEntropyParameter *bce_param =
reinterpret_cast<BinaryCrossEntropyParameter *>(malloc(sizeof(BinaryCrossEntropyParameter)));
if (bce_param == nullptr) {
MS_LOG(ERROR) << "malloc BinaryCrossEntropy Parameter failed.";
return nullptr;
}
memset(bce_param, 0, sizeof(BinaryCrossEntropyParameter));
bce_param->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::BinaryCrossEntropy *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
bce_param->reduction = param->GetReduction();
return reinterpret_cast<OpParameter *>(bce_param);
}
Registry BinaryCrossEntropyParameterRegistry(schema::PrimitiveType_BinaryCrossEntropy,
PopulateBinaryCrossEntropyParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,36 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/oneslike.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateOnesLikeParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc OnesLike Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}
Registry OnesLikeParameterRegistry(schema::PrimitiveType_OnesLike, PopulateOnesLikeParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/unsorted_segment_sum.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateUnsortedSegmentSumParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc UnsortedSegmentSum Parameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = primitive->Type();
return param;
}
Registry UnsortedSegmentSumParameterRegistry(schema::PrimitiveType_UnsortedSegmentSum,
PopulateUnsortedSegmentSumParameter);
} // namespace lite
} // namespace mindspore

View File

@@ -94,6 +94,12 @@ int Slice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inpu
}
}
}
std::vector<int> axes;
axes.clear();
for (size_t i = 0; i < attr->begin.size(); i++) {
axes.push_back(i);
}
attr->axes = axes;
}
this->primitive_->value.value = attr;
}

View File

@@ -17,6 +17,10 @@
#include <memory>
#include "src/ops/unsorted_segment_sum.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@@ -69,6 +73,11 @@ int UnsortedSegmentSum::GetNumSegments() const {
int ret = this->primitive_->value_as_UnsortedSegmentSum()->numSegments();
return ret;
}
PrimitiveC *UnsortedSegmentSumCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<UnsortedSegmentSum>(primitive);
}
Registry UnsortedSegmentSumRegistry(schema::PrimitiveType_UnsortedSegmentSum, UnsortedSegmentSumCreator);
#endif
int UnsortedSegmentSum::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
// check inputs and outputs