forked from mindspore-Ecosystem/mindspore
[feat][assistant][I48OCS]add new dynamic shape operator
This commit is contained in:
parent
90cbad6baf
commit
e9c9088984
|
@ -0,0 +1,106 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/apply_ftrl.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr ApplyFtrlInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t kInputNum = 8;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
auto var_shape = input_args[kInputIndex0]->BuildShape();
|
||||
auto accum_shape = input_args[kInputIndex1]->BuildShape();
|
||||
auto linear_shape = input_args[kInputIndex2]->BuildShape();
|
||||
if (var_shape->IsDynamic() || accum_shape->IsDynamic() || linear_shape->IsDynamic()) {
|
||||
return var_shape->cast<abstract::ShapePtr>();
|
||||
}
|
||||
std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
|
||||
same_shape_args_map.insert({"accum", accum_shape});
|
||||
same_shape_args_map.insert({"linear", linear_shape});
|
||||
for (auto &elem : same_shape_args_map) {
|
||||
if (*elem.second != *var_shape) {
|
||||
MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
|
||||
<< " are not consistent with var shape " << var_shape->ToString();
|
||||
}
|
||||
}
|
||||
auto shape_ptr = var_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
return shape_ptr;
|
||||
}
|
||||
TypePtr ApplyFtrlInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto prim_name = prim->name();
|
||||
const int64_t kInputNum = 8;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
auto var_type = input_args[kInputIndex0]->BuildType();
|
||||
auto accum_type = input_args[kInputIndex1]->BuildType();
|
||||
auto linear_type = input_args[kInputIndex2]->BuildType();
|
||||
auto grad_type = input_args[kInputIndex3]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
std::map<std::string, TypePtr> args;
|
||||
(void)args.insert({"var_type", var_type});
|
||||
(void)args.insert({"accum_type", accum_type});
|
||||
(void)args.insert({"linear_type", linear_type});
|
||||
(void)args.insert({"grad_type", grad_type});
|
||||
// var accum linear grad must have same dtypes
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
|
||||
|
||||
auto lr_type = input_args[kInputIndex4]->BuildType();
|
||||
auto l1_type = input_args[kInputIndex5]->BuildType();
|
||||
auto l2_type = input_args[kInputIndex6]->BuildType();
|
||||
auto lr_power_type = input_args[kInputIndex7]->BuildType();
|
||||
std::map<std::string, TypePtr> args_lr;
|
||||
std::map<std::string, TypePtr> args_l1;
|
||||
std::map<std::string, TypePtr> args_l2;
|
||||
std::map<std::string, TypePtr> args_lr_power;
|
||||
(void)args_lr.insert({"lr_type", lr_type});
|
||||
(void)args_l1.insert({"l1_type", l1_type});
|
||||
(void)args_l2.insert({"l2_type", l2_type});
|
||||
(void)args_lr_power.insert({"lr_power_type", lr_power_type});
|
||||
|
||||
// lr, l1, l2, lr_power type must be float or scalar tensor with float
|
||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr_power, valid_types, prim_name);
|
||||
|
||||
return var_type;
|
||||
}
|
||||
} // namespace
|
||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t kInputNum = 8;
|
||||
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
|
||||
auto infer_type = ApplyFtrlInferType(primitive, input_args);
|
||||
auto infer_shape = ApplyFtrlInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyFtrl, prim::kPrimApplyFtrl, ApplyFtrlInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_APPLY_FTRL_H_
|
||||
#define MINDSPORE_CORE_OPS_APPLY_FTRL_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameApplyFtrl = "ApplyFtrl";
|
||||
class ApplyFtrl : public PrimitiveC {
|
||||
public:
|
||||
ApplyFtrl() : PrimitiveC(kNameApplyFtrl) {
|
||||
InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
|
||||
}
|
||||
|
||||
~ApplyFtrl() = default;
|
||||
MS_DECLARE_PARENT(ApplyFtrl, PrimitiveC);
|
||||
};
|
||||
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_APPLY_FTRL_H_
|
|
@ -38,6 +38,7 @@ from .add_n_ds import _add_n_ds_tbe
|
|||
from .accumulate_n_v2 import _accumulate_n_v2_tbe
|
||||
from .accumulate_n_v2_ds import _accumulate_n_v2_ds_tbe
|
||||
from .apply_ftrl import _apply_ftrl_tbe
|
||||
from .apply_ftrl_ds import _apply_ftrl_ds_tbe
|
||||
from .apply_keras_momentum import _apply_keras_momentum_tbe
|
||||
from .apply_momentum import _apply_momentum_tbe
|
||||
from .apply_momentum_ds import _apply_momentum_ds_tbe
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""ApplyFtrl op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
apply_ftrl_ds_op_info = TBERegOp("ApplyFtrl") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("apply_ftrl.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("apply_ftrl_d") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "var", False, "required", "all") \
|
||||
.input(1, "accum", False, "required", "all") \
|
||||
.input(2, "linear", False, "required", "all") \
|
||||
.input(3, "grad", False, "required", "all") \
|
||||
.input(4, "lr", False, "required", "all") \
|
||||
.input(5, "l1", False, "required", "all") \
|
||||
.input(6, "l2", False, "required", "all") \
|
||||
.input(7, "lr_power", False, "required", "all") \
|
||||
.output(0, "var", False, "required", "all") \
|
||||
.output(1, "accum", False, "required", "all") \
|
||||
.output(2, "linear", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
|
||||
DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
|
||||
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
|
||||
DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
|
||||
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
|
||||
DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
|
||||
DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
|
||||
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
|
||||
DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
|
||||
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
|
||||
DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(apply_ftrl_ds_op_info)
|
||||
def _apply_ftrl_ds_tbe():
|
||||
"""Applyftrl TBE register"""
|
||||
return
|
|
@ -6569,7 +6569,7 @@ class LARSUpdate(PrimitiveWithInfer):
|
|||
validator.check_value_type("use_clip", use_clip, [bool], self.name)
|
||||
|
||||
|
||||
class ApplyFtrl(PrimitiveWithInfer):
|
||||
class ApplyFtrl(Primitive):
|
||||
"""
|
||||
Updates relevant entries according to the FTRL scheme.
|
||||
|
||||
|
@ -6644,24 +6644,6 @@ class ApplyFtrl(PrimitiveWithInfer):
|
|||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
|
||||
|
||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
|
||||
lr_power_shape):
|
||||
validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
|
||||
validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
|
||||
return var_shape
|
||||
|
||||
def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type,
|
||||
lr_power_type):
|
||||
valid_dtypes = [mstype.float16, mstype.float32]
|
||||
args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
|
||||
validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
|
||||
validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
|
||||
return var_type
|
||||
|
||||
|
||||
class SparseApplyFtrl(PrimitiveWithCheck):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue