[MSLITE][Develop] support al_bert inferance
This commit is contained in:
parent
e8fb3dfcc9
commit
4e9002329f
|
@ -103,3 +103,4 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T
|
|||
|
||||
REG_INFER(DivGrad, PrimType_DivGrad, ArithmeticGradInferShape)
|
||||
REG_INFER(MulGrad, PrimType_MulGrad, ArithmeticGradInferShape)
|
||||
REG_INFER(MinimumGrad, PrimType_MinimumGrad, ArithmeticGradInferShape)
|
||||
|
|
|
@ -37,6 +37,8 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
|||
if (!param->op_parameter_.infer_flag_) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
param->begin_norm_axis_ =
|
||||
param->begin_norm_axis_ < 0 ? param->begin_norm_axis_ + input->shape_size_ : param->begin_norm_axis_;
|
||||
SetShapeTensor(output, input);
|
||||
// take care of other outputs
|
||||
if (outputs_size == 3) {
|
||||
|
@ -45,10 +47,9 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
|||
SetDataTypeFormat(output_mean, input);
|
||||
SetDataTypeFormat(output_var, input);
|
||||
int size = 0;
|
||||
for (int i = param->begin_norm_axis_; i < input->shape_size_; i++) {
|
||||
output_mean->shape_[size] = input->shape_[i];
|
||||
output_var->shape_[size] = input->shape_[i];
|
||||
size++;
|
||||
for (; size < param->begin_norm_axis_; size++) {
|
||||
output_mean->shape_[size] = input->shape_[size];
|
||||
output_var->shape_[size] = input->shape_[size];
|
||||
}
|
||||
output_mean->shape_size_ = size;
|
||||
output_var->shape_size_ = size;
|
||||
|
|
|
@ -35,8 +35,8 @@ using mindspore::schema::PrimitiveType_ActivationGrad;
|
|||
|
||||
namespace mindspore::kernel {
|
||||
int ActivationGradCPUKernel::Init() {
|
||||
if (in_tensors_.size() != 2) {
|
||||
MS_LOG(ERROR) << "ActivationGrad should have 2 input tensors";
|
||||
if (in_tensors_.size() < 2) {
|
||||
MS_LOG(ERROR) << "ActivationGrad should have more than 2 input tensors";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue