forked from mindspore-Ecosystem/mindspore
!25059 dynamic shape ops fix
Merge pull request !25059 from wangnan39/dynamic_shape_ops_fix
This commit is contained in:
commit
6c9bd5b3cb
|
@ -20,14 +20,30 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const size_t kInputNum = 2;
|
||||
const int kInputNum = 2;
|
||||
const size_t one = 1;
|
||||
|
||||
void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) {
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
(*prev_is_one)[i] = current_is_one[i];
|
||||
}
|
||||
}
|
||||
void AddElementToGradReduceIdx(std::vector<std::vector<int64_t>> *grad_reduce_idx, std::vector<bool> current_is_one,
|
||||
bool none_is_one, const size_t largest_rank, size_t j) {
|
||||
MS_EXCEPTION_IF_NULL(grad_reduce_idx);
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
if (current_is_one[i] && !none_is_one) {
|
||||
(void)(*grad_reduce_idx)[i].emplace_back(SizeToLong(largest_rank - one - j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vector<int64_t>> &reverse_shape,
|
||||
const size_t largest_rank) {
|
||||
std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum);
|
||||
// indices of j-th component of each input.
|
||||
bool prev_is_one[kInputNum];
|
||||
bool current_is_one[kInputNum];
|
||||
std::vector<bool> prev_is_one(kInputNum);
|
||||
std::vector<bool> current_is_one(kInputNum);
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
prev_is_one[i] = false;
|
||||
current_is_one[i] = false;
|
||||
|
@ -46,37 +62,26 @@ std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vect
|
|||
} else {
|
||||
current_is_one[i] = false;
|
||||
if (!output_dim_set || reverse_shape[i][j] == static_cast<int64_t>(output_dim)) {
|
||||
output_dim = static_cast<int>(reverse_shape[i][j]);
|
||||
output_dim = LongToInt(reverse_shape[i][j]);
|
||||
output_dim_set = true;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Input[0] and input[1] Cannot broadcast!";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All dimensions are 1.
|
||||
if (!output_dim_set) {
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
|
||||
(void)grad_reduce_idx[i].emplace_back(SizeToLong(largest_rank - one - j));
|
||||
}
|
||||
continue;
|
||||
} else if (std::equal(current_is_one, current_is_one + kInputNum, prev_is_one) && set_one) {
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
if (current_is_one[i] && !none_is_one) {
|
||||
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
|
||||
}
|
||||
}
|
||||
} else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) {
|
||||
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
|
||||
} else {
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
if (current_is_one[i] && !none_is_one) {
|
||||
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
|
||||
}
|
||||
}
|
||||
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
|
||||
}
|
||||
set_one = true;
|
||||
for (size_t i = 0; i < kInputNum; ++i) {
|
||||
prev_is_one[i] = current_is_one[i];
|
||||
}
|
||||
UpdatePreIsOne(&prev_is_one, current_is_one);
|
||||
}
|
||||
return grad_reduce_idx;
|
||||
}
|
||||
|
@ -172,9 +177,10 @@ size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64
|
|||
*(data_ptr + i) = output[i];
|
||||
}
|
||||
|
||||
(void)out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()),
|
||||
tensor_for_sync->data_type(), tensor_for_sync->data_c(),
|
||||
tensor_for_sync->device_info().host_format_);
|
||||
if (!out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), tensor_for_sync->data_type(),
|
||||
tensor_for_sync->data_c(), tensor_for_sync->device_info().host_format_)) {
|
||||
MS_LOG(EXCEPTION) << "Output Value SyncHostToDevice failed.";
|
||||
}
|
||||
return out_size;
|
||||
}
|
||||
} // namespace
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace kernel {
|
|||
|
||||
void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *input_size_list, size_t *size_i) {
|
||||
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == -2) {
|
||||
auto input_max_shape = input_json[kJShape];
|
||||
auto input_max_shape = input_json[kJRange];
|
||||
for (auto &max_shape : input_max_shape) {
|
||||
(*size_i) *= LongToSize(max_shape[1]);
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *inp
|
|||
|
||||
void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *output_size_list, size_t *size_i) {
|
||||
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
|
||||
auto output_max_shape = output_json[kJShape];
|
||||
auto output_max_shape = output_json[kJRange];
|
||||
for (auto &max_shape : output_max_shape) {
|
||||
(*size_i) *= LongToSize(max_shape[1]);
|
||||
}
|
||||
|
|
|
@ -122,10 +122,9 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel
|
|||
void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(cast);
|
||||
using Shape = std::vector<size_t>;
|
||||
auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0);
|
||||
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
|
||||
std::vector<Shape> shapes;
|
||||
auto cast_shape = AnfAlgo::GetOutputDetailShape(cast, 0);
|
||||
std::vector<abstract::BaseShapePtr> shapes;
|
||||
std::vector<TypeId> types;
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
||||
for (size_t index = 0; index < output_num; ++index) {
|
||||
|
@ -134,10 +133,10 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
|
|||
(void)types.emplace_back(cast_dtype);
|
||||
continue;
|
||||
}
|
||||
(void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
|
||||
(void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(cnode, index));
|
||||
(void)types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
|
||||
}
|
||||
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
|
||||
AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
|
||||
auto prim_op = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
if (prim_op != nullptr) {
|
||||
(void)prim_op->AddAttr("cast_type", TypeIdToType(cast_dtype));
|
||||
|
|
|
@ -48,6 +48,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
|
|||
{"Softmax", "SoftmaxV2"},
|
||||
{"DropoutDoMask", "DropOutDoMask"},
|
||||
{"IOU", "Iou"},
|
||||
{"DynamicBroadcastTo", "BroadcastTo"},
|
||||
};
|
||||
auto iter = kOpTypeMap.find(op_type);
|
||||
if (iter == kOpTypeMap.end()) {
|
||||
|
|
|
@ -110,18 +110,34 @@ TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &acce
|
|||
return CheckType(type, accepts, error_message_prefix);
|
||||
}
|
||||
|
||||
ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
|
||||
void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
|
||||
MS_EXCEPTION_IF_NULL(tensor_base);
|
||||
ShapePtr shape_base = tensor_base->shape();
|
||||
MS_EXCEPTION_IF_NULL(shape_base);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
ShapePtr shape = tensor->shape();
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (*shape != *shape_base) {
|
||||
if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto shape_vector = shape->shape();
|
||||
auto shape_base_vector = shape_base->shape();
|
||||
if (shape_vector.size() != shape_base_vector.size()) {
|
||||
MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
|
||||
<< " are not consistent with second arg shape " << shape_base->ToString();
|
||||
}
|
||||
return shape_base;
|
||||
|
||||
for (size_t i = 0; i < shape_vector.size(); i++) {
|
||||
if (shape_vector[i] == Shape::SHP_ANY || shape_base_vector[i] == Shape::SHP_ANY) {
|
||||
continue;
|
||||
}
|
||||
if (shape_vector[i] != shape_base_vector[i]) {
|
||||
MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
|
||||
<< " are not consistent with second arg shape " << shape_base->ToString();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
|
||||
|
|
|
@ -41,7 +41,7 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
|
|||
TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
|
||||
const std::string &error_message_prefix);
|
||||
|
||||
ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
|
||||
void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
|
||||
|
||||
TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
for (size_t i = 1; i < tuple_len; ++i) {
|
||||
AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
|
||||
(void)CheckDtypeSame(op_name, tensor_base, tensor);
|
||||
(void)CheckShapeSame(op_name, tensor_base, tensor);
|
||||
CheckShapeSame(op_name, tensor_base, tensor);
|
||||
}
|
||||
auto element = tensor_base->element();
|
||||
MS_EXCEPTION_IF_NULL(element);
|
||||
|
@ -1241,6 +1241,7 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive
|
|||
AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
bool output_shape_unknow = false;
|
||||
auto prim_name = primitive->name();
|
||||
constexpr int64_t args_size = 2;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kEqual, args_size,
|
||||
|
@ -1253,6 +1254,22 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
|
|||
auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_tuple);
|
||||
auto indices = input_tuple->elements();
|
||||
auto input_indice_size = input_tuple->size();
|
||||
int64_t first_dim_size = 0;
|
||||
for (size_t i = 0; i < input_indice_size; i++) {
|
||||
auto indicei = indices[i]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(indicei);
|
||||
auto valuei = indicei->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(valuei);
|
||||
if (!valuei->isa<tensor::Tensor>()) {
|
||||
output_shape_unknow = true;
|
||||
continue;
|
||||
}
|
||||
auto indicei_value = CheckAndConvertUtils::CheckTensorIntValue("indices", valuei, prim_name);
|
||||
auto indicei_max = std::max_element(indicei_value.begin(), indicei_value.end());
|
||||
first_dim_size = *indicei_max > first_dim_size ? *indicei_max : first_dim_size;
|
||||
}
|
||||
|
||||
auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(indices0);
|
||||
auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
|
||||
|
@ -1282,7 +1299,12 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
|
|||
std::set<TypePtr> valid_types = ops::common_valid_types;
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
|
||||
|
||||
ShapeVector out_shape = {abstract::Shape::SHP_ANY};
|
||||
ShapeVector out_shape;
|
||||
if (output_shape_unknow) {
|
||||
out_shape.push_back(abstract::Shape::SHP_ANY);
|
||||
} else {
|
||||
out_shape.push_back(first_dim_size + 1);
|
||||
}
|
||||
for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
|
||||
out_shape.push_back(data0_shape[i]);
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
(void)CheckDtypeSame(op_name, out, dout);
|
||||
(void)CheckShapeSame(op_name, out, dout);
|
||||
CheckShapeSame(op_name, out, dout);
|
||||
|
||||
return out->Broaden();
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
|
|||
const auto kStridedSlice = prim::kPrimStridedSlice->name();
|
||||
const auto kStridedSliceGrad = prim::kPrimStridedSliceGrad->name();
|
||||
const auto kReduceSum = prim::kPrimReduceSum->name();
|
||||
const auto kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
|
||||
const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
|
||||
const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
|
||||
const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
|
||||
|
@ -78,7 +79,8 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
|
|||
{kTile, {1}},
|
||||
{kReshape, {1}},
|
||||
{kSlice, {1, 2}},
|
||||
{kSliceGrad, {2, 3}}};
|
||||
{kSliceGrad, {2, 3}},
|
||||
{kDynamicBroadcastTo, {1}}};
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
|
|
|
@ -95,6 +95,7 @@ constexpr auto kDiagPart = "DiagPart";
|
|||
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
|
||||
constexpr auto kTranspose = "Transpose";
|
||||
constexpr auto kSplitV = "SplitV";
|
||||
constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo";
|
||||
|
||||
// NN
|
||||
constexpr auto kCTCLoss = "CTCLoss";
|
||||
|
@ -170,6 +171,7 @@ inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPus
|
|||
inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop");
|
||||
|
||||
// Arrays
|
||||
inline const PrimitivePtr kPrimDynamicBroadcastTo = std::make_shared<Primitive>(kDynamicBroadcastTo);
|
||||
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
|
||||
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
|
||||
inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK");
|
||||
|
|
|
@ -73,8 +73,8 @@ abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t x_index = 0;
|
||||
return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim->name());
|
||||
const size_t x_index = 0;
|
||||
return CheckAndConvertUtils::GetTensorInputType(prim->name(), input_args, x_index);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -115,7 +115,20 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac
|
|||
std::map<std::string, TypePtr> types;
|
||||
(void)types.emplace("x", input_args[0]->BuildType());
|
||||
(void)types.emplace("w", input_args[1]->BuildType());
|
||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
TypePtr x_type = input_args[0]->BuildType();
|
||||
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
|
||||
x_type = kInt32;
|
||||
}
|
||||
if (prim->HasAttr("cast_type")) {
|
||||
auto out_type = prim->GetAttr("cast_type");
|
||||
MS_EXCEPTION_IF_NULL(out_type);
|
||||
if (!out_type->isa<Type>()) {
|
||||
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
|
||||
}
|
||||
x_type = out_type->cast<TypePtr>();
|
||||
}
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -48,17 +48,19 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto x_shape_vector = x_shape->shape();
|
||||
auto mask_shape_vector = mask_shape->shape();
|
||||
|
||||
int64_t x_size = 1;
|
||||
for (size_t i = 0; i < x_shape_vector.size(); i++) {
|
||||
x_size *= x_shape_vector[i];
|
||||
}
|
||||
if (mask_shape_vector.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension.";
|
||||
}
|
||||
auto mask_size = mask_shape_vector[0] * 8;
|
||||
if (x_size > mask_size) {
|
||||
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
|
||||
<< ", mask shape: " << mask_shape->ToString();
|
||||
if (!x_shape->IsDynamic() && !mask_shape->IsDynamic()) {
|
||||
int64_t x_size = 1;
|
||||
for (size_t i = 0; i < x_shape_vector.size(); i++) {
|
||||
x_size *= x_shape_vector[i];
|
||||
}
|
||||
if (mask_shape_vector.size() != 1) {
|
||||
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension.";
|
||||
}
|
||||
auto mask_size = mask_shape_vector[0] * 8;
|
||||
if (x_size > mask_size) {
|
||||
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
|
||||
<< ", mask shape: " << mask_shape->ToString();
|
||||
}
|
||||
}
|
||||
auto keep_prop = input_args[kInputIndex2];
|
||||
if (keep_prop->isa<abstract::AbstractTensor>()) {
|
||||
|
|
|
@ -86,7 +86,6 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
|
|||
}
|
||||
count = count * value;
|
||||
}
|
||||
|
||||
// convert to bytes(8 bits) mask, using round up
|
||||
int64_t n128s = count / mask_convert_len;
|
||||
if ((count % mask_convert_len) != 0) {
|
||||
|
@ -106,7 +105,18 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
AbstractBasePtr shape_args = input_args[0];
|
||||
MS_EXCEPTION_IF_NULL(shape_args);
|
||||
|
||||
ShapeVector out_shape;
|
||||
if (shape_args->isa<abstract::AbstractTensor>()) {
|
||||
auto shape_value = shape_args->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(shape_value);
|
||||
if (shape_value->isa<tensor::Tensor>()) {
|
||||
auto mask_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", shape_value, op_name);
|
||||
std::vector<ValuePtr> value_elements;
|
||||
std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(value_elements),
|
||||
[](int64_t elem) { return MakeValue(elem); });
|
||||
out_shape = CalDynamicOutputShape(value_elements);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
auto shape_abstract = dyn_cast<abstract::AbstractTensor>(shape_args);
|
||||
MS_EXCEPTION_IF_NULL(shape_abstract);
|
||||
auto shape_base = shape_abstract->BuildShape();
|
||||
|
@ -139,7 +149,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
auto x_shape = dyn_cast<abstract::AbstractTuple>(shape_args);
|
||||
auto x_shape_data = x_shape->elements();
|
||||
ShapeVector out_shape = CalOutputShape(x_shape_data);
|
||||
out_shape = CalOutputShape(x_shape_data);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/dynamic_broadcast_to.h"
|
||||
|
||||
#include <set>
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
|
||||
auto input_y = input_args[1];
|
||||
MS_EXCEPTION_IF_NULL(input_y);
|
||||
abstract::ShapePtr y_shape;
|
||||
auto y_value = input_y->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(y_value);
|
||||
if (input_y->isa<abstract::AbstractTensor>()) {
|
||||
if (y_value->isa<tensor::Tensor>()) {
|
||||
auto shape_value = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name);
|
||||
return std::make_shared<abstract::Shape>(shape_value);
|
||||
}
|
||||
y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
|
||||
auto shape_value = y_shape->shape();
|
||||
if (shape_value.size() != 1) {
|
||||
MS_EXCEPTION(TypeError) << "shape size error: " << shape_value.size();
|
||||
}
|
||||
std::vector<int64_t> output_shape;
|
||||
std::vector<int64_t> max_shape;
|
||||
std::vector<int64_t> min_shape;
|
||||
if (y_shape->IsDynamic()) {
|
||||
// max shape unknown
|
||||
output_shape.push_back(-2);
|
||||
} else {
|
||||
auto out_dims = LongToSize(y_shape->shape()[0]);
|
||||
for (size_t i = 0; i < out_dims; i++) {
|
||||
output_shape.push_back(-1);
|
||||
}
|
||||
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
|
||||
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
|
||||
if (!min_value || !max_value) {
|
||||
MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value is empty.";
|
||||
}
|
||||
min_shape = GetValue<std::vector<int64_t>>(min_value);
|
||||
max_shape = GetValue<std::vector<int64_t>>(max_value);
|
||||
if (min_shape.size() != out_dims || max_shape.size() != out_dims) {
|
||||
MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value not match with out dims.";
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
|
||||
} else if (input_y->isa<abstract::AbstractTuple>()) {
|
||||
auto out_shape = GetValue<std::vector<int64_t>>(y_value);
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "For BroadcastTo, input args must be tensor or tuple.";
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
|
||||
std::set<TypePtr> template_types = {kTensorType};
|
||||
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
|
||||
return x_dtype->element();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastTo, prim::kPrimDynamicBroadcastTo, DynamicBroadcastToInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_
|
||||
#define MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class DynamicBroadcastTo : public PrimitiveC {
|
||||
public:
|
||||
DynamicBroadcastTo() : PrimitiveC(prim::kPrimDynamicBroadcastTo->name()) { InitIOName({"x", "shape"}, {"y"}); }
|
||||
~DynamicBroadcastTo() = default;
|
||||
MS_DECLARE_PARENT(DynamicBroadcastTo, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimDynamicBroadcastToPtr = std::shared_ptr<DynamicBroadcastTo>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_
|
|
@ -48,8 +48,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
|
|||
(void)out_shape.insert(out_shape.begin() + dim_val, 1, 1);
|
||||
|
||||
// Infer type
|
||||
const int64_t x_index = 0;
|
||||
auto x_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim_name);
|
||||
const size_t x_index = 0;
|
||||
auto x_type = CheckAndConvertUtils::GetTensorInputType(prim_name, input_args, x_index);
|
||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||
(void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
|
||||
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);
|
||||
|
|
|
@ -24,9 +24,9 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
constexpr int64_t kDoutIndex = 0;
|
||||
constexpr int64_t kInputIndex = 1;
|
||||
constexpr int64_t kFilterSizeIdex = 2;
|
||||
constexpr size_t kDoutIndex = 0;
|
||||
constexpr size_t kInputIndex = 1;
|
||||
constexpr size_t kFilterSizeIdex = 2;
|
||||
constexpr size_t kStride2dSize = 2;
|
||||
constexpr size_t kStride4dSize = 4;
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace ops {
|
|||
namespace {
|
||||
constexpr size_t kDoutIndex = 0;
|
||||
constexpr size_t kInputIndex = 1;
|
||||
constexpr int64_t kSizeIndex = 2;
|
||||
constexpr size_t kSizeIndex = 2;
|
||||
|
||||
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
|
||||
const std::vector<int64_t> &x_size_v) {
|
||||
|
|
|
@ -44,8 +44,8 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
|
|||
const int64_t input_num = 2;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
|
||||
|
||||
const int64_t dy_index = 0;
|
||||
const int64_t mask_index = 1;
|
||||
const size_t dy_index = 0;
|
||||
const size_t mask_index = 1;
|
||||
auto dy_type = input_args[dy_index]->BuildType();
|
||||
auto mask_type = input_args[mask_index]->BuildType();
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
|
||||
(void)abstract::CheckShapeSame(prim_name, out, dout);
|
||||
abstract::CheckShapeSame(prim_name, out, dout);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
|
|
|
@ -124,7 +124,7 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const
|
|||
StridedSliceGradInferType(primitive, input_args));
|
||||
}
|
||||
|
||||
void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) {
|
||||
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ int64_t StridedSliceGrad::get_begin_mask() const {
|
|||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_end_mask(const int64_t end_mask) {
|
||||
void StridedSliceGrad::set_end_mask(int64_t end_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ int64_t StridedSliceGrad::get_end_mask() const {
|
|||
auto value_ptr = GetAttr(kEndMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
||||
void StridedSliceGrad::set_ellipsis_mask(int64_t ellipsis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||
std::ostringstream buffer;
|
||||
|
@ -155,7 +155,7 @@ int64_t StridedSliceGrad::get_ellipsis_mask() const {
|
|||
auto value_ptr = GetAttr(kEllipsisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) {
|
||||
void StridedSliceGrad::set_new_axis_mask(int64_t new_axis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||
}
|
||||
|
@ -163,7 +163,7 @@ int64_t StridedSliceGrad::get_new_axis_mask() const {
|
|||
auto value_ptr = GetAttr(kNewAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
||||
void StridedSliceGrad::set_shrink_axis_mask(int64_t shrink_axis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||
}
|
||||
|
@ -171,8 +171,8 @@ int64_t StridedSliceGrad::get_shrink_axis_mask() const {
|
|||
auto value_ptr = GetAttr(kShrinkAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
||||
void StridedSliceGrad::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
|
||||
int64_t shrink_axis_mask) {
|
||||
this->set_begin_mask(begin_mask);
|
||||
this->set_end_mask(end_mask);
|
||||
this->set_ellipsis_mask(ellipsis_mask);
|
||||
|
|
|
@ -34,13 +34,13 @@ class MS_CORE_API StridedSliceGrad : public PrimitiveC {
|
|||
|
||||
~StridedSliceGrad() = default;
|
||||
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
|
||||
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
|
||||
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
|
||||
void set_begin_mask(const int64_t begin_mask);
|
||||
void set_end_mask(const int64_t end_mask);
|
||||
void set_ellipsis_mask(const int64_t ellipsis_mask);
|
||||
void set_new_axis_mask(const int64_t new_axis_mask);
|
||||
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
|
||||
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
|
||||
int64_t shrink_axis_mask = 0);
|
||||
void set_begin_mask(int64_t begin_mask);
|
||||
void set_end_mask(int64_t end_mask);
|
||||
void set_ellipsis_mask(int64_t ellipsis_mask);
|
||||
void set_new_axis_mask(int64_t new_axis_mask);
|
||||
void set_shrink_axis_mask(int64_t shrink_axis_mask);
|
||||
int64_t get_begin_mask() const;
|
||||
int64_t get_end_mask() const;
|
||||
int64_t get_ellipsis_mask() const;
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace ops {
|
|||
namespace {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto op_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("infer_shape", int64_t(input_args.size()), kGreaterEqual, 1, op_name);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 1, op_name);
|
||||
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
|
||||
}
|
||||
|
||||
|
|
|
@ -33,9 +33,9 @@ AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const P
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t input_num = 1;
|
||||
const int64_t x_index = 0;
|
||||
const size_t x_index = 0;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
|
||||
auto input_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
|
||||
auto input_type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
|
||||
auto dst_type = TypeIdToType(TypeId(GetValue<int64_t>(primitive->GetAttr(kDstT))));
|
||||
MS_EXCEPTION_IF_NULL(dst_type);
|
||||
if (input_type != dst_type) {
|
||||
|
|
|
@ -48,7 +48,7 @@ void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape,
|
|||
ValuePtrList::iterator it;
|
||||
if (keep_dims_value) {
|
||||
for (it = axis_items.begin(); it != axis_items.end(); ++it) {
|
||||
auto axis_value = GetValue<int64_t>(*it);
|
||||
auto axis_value = InferImplReduceFuncCheckAxis(GetValue<int64_t>(*it), x_shape.size());
|
||||
shape->at(LongToSize(axis_value)) = 1;
|
||||
}
|
||||
} else {
|
||||
|
@ -108,10 +108,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
auto axis_shape = axis_tensor->shape()->shape();
|
||||
if (axis_shape.size() == 1 && axis_shape[0] == -1 && !keep_dims) {
|
||||
out_shape.push_back(-2);
|
||||
for (size_t i = 0; i < input_shape.size(); ++i) {
|
||||
out_min_shape.push_back(1);
|
||||
out_max_shape.push_back(max_v);
|
||||
}
|
||||
out_min_shape = input_min_shape;
|
||||
out_max_shape = input_max_shape;
|
||||
} else if (!keep_dims) {
|
||||
for (size_t i = 0; i < input_shape.size() - axis_shape.size(); ++i) {
|
||||
out_shape.push_back(-1);
|
||||
|
@ -136,7 +134,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
}
|
||||
MS_EXCEPTION_IF_NULL(axis_ptr);
|
||||
if (axis_ptr->isa<tensor::Tensor>()) {
|
||||
MS_LOG(ERROR) << "Tensor with value";
|
||||
auto axis_type = input_args[1]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(axis_type);
|
||||
auto axis_type_id = axis_type->cast<TensorTypePtr>();
|
||||
|
@ -178,8 +175,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
|
|||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return CheckAndConvertUtils::CheckTensorTypeValid("x dtype", input_args[0]->BuildType(), common_valid_types,
|
||||
"ReduceSum");
|
||||
auto x_type = input_args[0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, common_valid_types, prim->name());
|
||||
return x_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -321,7 +321,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t x_index = 0;
|
||||
const size_t x_index = 0;
|
||||
auto x_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, x_index);
|
||||
if (x_shape->IsDynamic()) {
|
||||
MS_EXCEPTION(ValueError) << "input x dynamic shape is currently not supported.";
|
||||
|
@ -363,12 +363,12 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
|||
}
|
||||
|
||||
TypePtr StridedSliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const int64_t x_index = 0;
|
||||
return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
|
||||
const size_t x_index = 0;
|
||||
return CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
|
||||
void StridedSlice::set_begin_mask(int64_t begin_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||
}
|
||||
|
@ -376,7 +376,7 @@ int64_t StridedSlice::get_begin_mask() const {
|
|||
auto value_ptr = GetAttr(kBeginMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_end_mask(const int64_t end_mask) {
|
||||
void StridedSlice::set_end_mask(int64_t end_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||
}
|
||||
|
@ -384,7 +384,7 @@ int64_t StridedSlice::get_end_mask() const {
|
|||
auto value_ptr = GetAttr(kEndMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
||||
void StridedSlice::set_ellipsis_mask(int64_t ellipsis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||
std::ostringstream buffer;
|
||||
|
@ -398,7 +398,7 @@ int64_t StridedSlice::get_ellipsis_mask() const {
|
|||
auto value_ptr = GetAttr(kEllipsisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
|
||||
void StridedSlice::set_new_axis_mask(int64_t new_axis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||
}
|
||||
|
@ -406,7 +406,7 @@ int64_t StridedSlice::get_new_axis_mask() const {
|
|||
auto value_ptr = GetAttr(kNewAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
||||
void StridedSlice::set_shrink_axis_mask(int64_t shrink_axis_mask) {
|
||||
(void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||
(void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||
}
|
||||
|
@ -414,8 +414,8 @@ int64_t StridedSlice::get_shrink_axis_mask() const {
|
|||
auto value_ptr = GetAttr(kShrinkAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
||||
void StridedSlice::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
|
||||
int64_t shrink_axis_mask) {
|
||||
this->set_begin_mask(begin_mask);
|
||||
this->set_end_mask(end_mask);
|
||||
this->set_ellipsis_mask(ellipsis_mask);
|
||||
|
|
|
@ -38,18 +38,18 @@ class MS_CORE_API StridedSlice : public PrimitiveC {
|
|||
~StridedSlice() = default;
|
||||
MS_DECLARE_PARENT(StridedSlice, PrimitiveC);
|
||||
/// \brief Init. Refer to the parameters of python API @ref mindspore.ops.StridedSlice for the inputs.
|
||||
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
|
||||
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
|
||||
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
|
||||
int64_t shrink_axis_mask = 0);
|
||||
/// \brief Set begin_mask.
|
||||
void set_begin_mask(const int64_t begin_mask);
|
||||
void set_begin_mask(int64_t begin_mask);
|
||||
/// \brief Set end_mask.
|
||||
void set_end_mask(const int64_t end_mask);
|
||||
void set_end_mask(int64_t end_mask);
|
||||
/// \brief Set ellipsis_mask.
|
||||
void set_ellipsis_mask(const int64_t ellipsis_mask);
|
||||
void set_ellipsis_mask(int64_t ellipsis_mask);
|
||||
/// \brief Set new_axis_mask.
|
||||
void set_new_axis_mask(const int64_t new_axis_mask);
|
||||
void set_new_axis_mask(int64_t new_axis_mask);
|
||||
/// \brief Set shrink_axis_mask.
|
||||
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
|
||||
void set_shrink_axis_mask(int64_t shrink_axis_mask);
|
||||
/// \brief Get begin_mask.
|
||||
///
|
||||
/// \return begin_mask.
|
||||
|
|
|
@ -378,28 +378,6 @@ void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &in
|
|||
}
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
|
||||
const std::string &prim_name) {
|
||||
if (input_args.size() <= index) {
|
||||
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index
|
||||
<< "] is out of the input number " << input_args.size();
|
||||
}
|
||||
auto input_arg = input_args[index];
|
||||
if (input_arg == nullptr) {
|
||||
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is nullptr.";
|
||||
}
|
||||
auto base_type = input_arg->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(base_type);
|
||||
if (!base_type->isa<TensorType>()) {
|
||||
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is not a tensor.";
|
||||
}
|
||||
auto tensor_type = base_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto type = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
return type;
|
||||
}
|
||||
|
||||
ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
if (!shape->isa<abstract::Shape>()) {
|
||||
|
@ -416,8 +394,8 @@ ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &sha
|
|||
|
||||
abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
|
||||
const std::vector<AbstractBasePtr> &input_args,
|
||||
int64_t index) {
|
||||
auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, LongToSize(index));
|
||||
size_t index) {
|
||||
auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, index);
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
auto base_shape = abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(base_shape);
|
||||
|
@ -429,6 +407,28 @@ abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &
|
|||
return shape;
|
||||
}
|
||||
|
||||
TypePtr CheckAndConvertUtils::GetTensorInputType(const std::string &prim_name,
|
||||
const std::vector<AbstractBasePtr> &input_args, size_t index) {
|
||||
if (input_args.size() <= index) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the index " << index << " is out of the input number "
|
||||
<< input_args.size();
|
||||
}
|
||||
auto input_arg = input_args[index];
|
||||
if (input_arg == nullptr) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
|
||||
}
|
||||
auto base_type = input_arg->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(base_type);
|
||||
if (!base_type->isa<TensorType>()) {
|
||||
MS_EXCEPTION(ValueError) << "The " << index << "'s input type of " << prim_name << " is not Tensor.";
|
||||
}
|
||||
auto tensor_type = base_type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto type = tensor_type->element();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
return type;
|
||||
}
|
||||
|
||||
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
|
||||
int64_t value, const string &prim_name, ExceptionType) {
|
||||
auto iter = kCompareMap<float>.find(compare_type);
|
||||
|
|
|
@ -219,7 +219,9 @@ class CheckAndConvertUtils {
|
|||
|
||||
static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape);
|
||||
static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name,
|
||||
const std::vector<AbstractBasePtr> &input_args, int64_t index);
|
||||
const std::vector<AbstractBasePtr> &input_args, size_t index);
|
||||
static TypePtr GetTensorInputType(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
|
||||
size_t index);
|
||||
static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
|
||||
const std::string &value_name, int64_t value, const std::string &prim_name = "",
|
||||
ExceptionType exception_type = ValueError);
|
||||
|
@ -313,8 +315,6 @@ class CheckAndConvertUtils {
|
|||
static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
|
||||
static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator,
|
||||
const int64_t match_value, const std::string &prim_name);
|
||||
static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
|
||||
const std::string &prim_name);
|
||||
static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
|
||||
|
||||
private:
|
||||
|
|
|
@ -55,7 +55,9 @@ from .batch_matmul_ds import _batch_matmul_ds_tbe
|
|||
from .batchnorm import _batch_norm_tbe
|
||||
from .batchnorm_grad import _batch_norm_grad_tbe
|
||||
from .bias_add import _bias_add_tbe
|
||||
from .bias_add_ds import _bias_add_ds_tbe
|
||||
from .bias_add_grad import _bias_add_grad_tbe
|
||||
from .bias_add_grad_ds import _bias_add_grad_ds_tbe
|
||||
from .cast import _cast_tbe
|
||||
from .cast_ds import _cast_ds_tbe
|
||||
from .conv2d import _conv2d_tbe
|
||||
|
@ -113,6 +115,7 @@ from .scatter_nd_sub import _scatter_nd_sub_tbe
|
|||
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
|
||||
from .reduce_mean import _reduce_mean_tbe
|
||||
from .tile import _tile_tbe
|
||||
from .tile_ds import _tile_ds_tbe
|
||||
from .atomic_addr_clean import _atomic_addr_clean_tbe
|
||||
from .gather_v2 import _gather_v2_tbe
|
||||
from .gather_v2_ds import _gather_v2_ds_tbe
|
||||
|
@ -186,6 +189,7 @@ from .sparse_apply_proximal_adagrad_ds import _sparse_apply_proximal_adagrad_ds
|
|||
from .apply_proximal_adagrad import _apply_proximal_adagrad
|
||||
from .transpose import _transpose_tbe
|
||||
from .transpose_d import _transpose_d_tbe
|
||||
from .transpose_ds import _transpose_ds_tbe
|
||||
from .truncate_div import _truncate_div_tbe
|
||||
from .truncate_mod import _truncate_mod_tbe
|
||||
from .unsorted_segment_sum import _unsorted_segment_sum_tbe
|
||||
|
@ -352,6 +356,7 @@ from .gru_v2_hidden_grad_cell import _gru_v2_hidden_grad_cell_tbe
|
|||
from .lstm_input_grad import _lstm_input_grad_tbe
|
||||
from .confusion_matrix import _confusion_matrix_tbe
|
||||
from .broadcast_to import _broadcast_to_tbe
|
||||
from .broadcast_to_ds import _broadcast_to_ds_tbe
|
||||
from .strided_read import _strided_read_tbe
|
||||
from .strided_write import _strided_write_tbe
|
||||
from .range import _range_tbe
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BiasAdd op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bias_add_grad_op_info = TBERegOp("BiasAdd") \
|
||||
.fusion_type("COMMREDUCE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bias_add.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bias_add") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("format", "required", "str", "all") \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "bias", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.is_dynamic_format(True) \
|
||||
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bias_add_grad_op_info)
|
||||
def _bias_add_ds_tbe():
|
||||
"""BiasAdd TBE register"""
|
||||
return
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BiasAddGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
|
||||
.fusion_type("COMMREDUCE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("bias_add_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("bias_add_grad") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.attr("format", "required", "str", "all") \
|
||||
.input(0, "output_backprop", False, "required", "all") \
|
||||
.output(0, "output", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
|
||||
.dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NHWC) \
|
||||
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(bias_add_grad_op_info)
|
||||
def _bias_add_grad_ds_tbe():
|
||||
"""BiasAddGrad TBE register"""
|
||||
return
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""BroadcastTo op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
broadcast_to_op_info = TBERegOp("DynamicBroadcastTo") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("broadcast_to.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("broadcast_to") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "shape", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(broadcast_to_op_info)
|
||||
def _broadcast_to_ds_tbe():
|
||||
"""BroadcastTo TBE register"""
|
||||
return
|
|
@ -34,12 +34,8 @@ matmul_op_info = TBERegOp("MatMul") \
|
|||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
|
||||
DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
|
||||
DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
|
||||
DataType.F16_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
|
||||
DataType.F32_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Dynamic Tile op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
tile_op_info = TBERegOp("Tile") \
|
||||
.fusion_type("ELEMWISE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("tile.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("tile") \
|
||||
.partial_flag(True) \
|
||||
.dynamic_shape(True) \
|
||||
.input(0, "x1", False, "required", "all") \
|
||||
.input(1, "multiples", False, "required", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(tile_op_info)
|
||||
def _tile_ds_tbe():
|
||||
"""Tile TBE register"""
|
||||
return
|
|
@ -1403,3 +1403,29 @@ class SliceGetItem(Primitive):
|
|||
if value == "step":
|
||||
return slice_value.step
|
||||
raise AttributeError("\'slice\' object has no attribute {}".format(value))
|
||||
|
||||
|
||||
class DynamicBroadcastTo(Primitive):
|
||||
"""
|
||||
Broadcasts input tensor to a given shape.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
|
||||
float16, float32, int32, int8, uint8.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
- **shape** (Tensor): The target shape to broadcast.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the given `shape` and the same data type as `input_x`.
|
||||
|
||||
Raises:
|
||||
ValueError: if the target and input shapes are incompatible.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend``
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize DynamicBroadcastTo"""
|
||||
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])
|
||||
|
|
|
@ -521,7 +521,7 @@ class Reshape(PrimitiveWithInfer):
|
|||
neg_index = i
|
||||
else:
|
||||
dim_prod *= shp_i
|
||||
arr_prod = np.prod(x_shp)
|
||||
|
||||
if -1 in x_shp:
|
||||
if 'max_shape' in x:
|
||||
x_max_shape = x['max_shape']
|
||||
|
@ -545,6 +545,7 @@ class Reshape(PrimitiveWithInfer):
|
|||
'max_shape': tuple(max_shape),
|
||||
'min_shape': tuple(min_shape)}
|
||||
else:
|
||||
arr_prod = np.prod(x_shp)
|
||||
if dim_prod <= 0:
|
||||
raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
|
||||
f"the value of 'input_shape' is {shape_v}. "
|
||||
|
|
|
@ -451,9 +451,8 @@ class _Reduce(PrimitiveWithInfer):
|
|||
axis_shape = axis_shape_list[0]
|
||||
if axis_shape == -1 and not self.keep_dims:
|
||||
out_shape = np.array([-2]).tolist()
|
||||
output_min_shape = np.ones_like(input_shp).tolist()
|
||||
output_max_shape = max_v * np.ones_like(input_shp)
|
||||
output_max_shape = output_max_shape.tolist()
|
||||
output_min_shape = input_x['min_shape']
|
||||
output_max_shape = input_x['max_shape']
|
||||
elif not self.keep_dims:
|
||||
out_shape = -1 * np.ones_like(input_shp[:-axis_shape])
|
||||
out_shape = out_shape.tolist()
|
||||
|
@ -467,12 +466,12 @@ class _Reduce(PrimitiveWithInfer):
|
|||
output_max_shape = max_v * np.ones_like(input_shp)
|
||||
output_max_shape = output_max_shape.tolist()
|
||||
else:
|
||||
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
output_max_shape = _infer_shape_reduce(input_x['max_shape'], axis_v, self.keep_dims, self.name)
|
||||
output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name)
|
||||
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
else:
|
||||
if axis_v is None:
|
||||
raise ValueError(f"For {self.name}, the 'axis' cannot be None.")
|
||||
raise ValueError(f"For {self.name}, axis must be const, its value cannot be None.")
|
||||
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
|
||||
output_max_shape = out_shape
|
||||
output_min_shape = out_shape
|
||||
|
|
|
@ -160,8 +160,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|||
validator.check_value_type('network', network, nn.Cell)
|
||||
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt))
|
||||
if not isinstance(level, str):
|
||||
raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
|
||||
but got type {}.".format(type(level)))
|
||||
raise TypeError(f"The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], "
|
||||
f"but got type {str(type(level))}.")
|
||||
validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
|
||||
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)
|
||||
|
||||
|
|
Loading…
Reference in New Issue