UniqueWithPad dyn shape support

update

update

update

update

update
This commit is contained in:
hw_hz 2022-11-01 15:38:20 +08:00
parent 900c7bec92
commit 12f44bcd1e
11 changed files with 205 additions and 128 deletions

View File

@ -19,6 +19,37 @@
namespace mindspore {
namespace kernel {
int UniqueWithPadCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kUniqueWithPadInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniqueWithPadOutputsNum, kernel_name_);
auto input_shape = inputs[0]->GetShapeVector();
input_size_ = input_shape[0];
batch_size_ = 1;
if (batch_rank_ > 0) {
auto pad_shape = inputs[kPadNumIndex]->GetShapeVector();
auto pad_nums = std::accumulate(pad_shape.begin(), pad_shape.end(), 1, std::multiplies<int64_t>());
batch_size_ =
std::accumulate(input_shape.begin(), input_shape.begin() + batch_rank_, 1, std::multiplies<int64_t>());
input_size_ = input_shape[input_shape.size() - 1];
if (pad_nums != static_cast<int64_t>(batch_size_)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the elements num of input 'pad' must be equal to input 'x' batch size, "
"but got the elements num of input 'pad': "
<< Vector2Str(pad_shape) << " and input 'x' batch size: " << batch_size_;
}
}
workspace_size_list_.clear();
(void)workspace_size_list_.emplace_back(input_size_ * sizeof(int64_t));
(void)workspace_size_list_.emplace_back(input_size_ * sizeof(int64_t));
(void)workspace_size_list_.emplace_back(input_size_ * sizeof(int64_t));
return KRET_OK;
}
bool UniqueWithPadCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
@ -41,6 +72,24 @@ bool UniqueWithPadCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
return true;
}
template <typename T>
void UniqueWithPadCpuKernelMod::PadOutput(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
const std::vector<size_t> &start) {
if (inputs.size() < kUniqueWithPadInputsNum || outputs.size() < kUniqueWithPadOutputsNum) {
return;
}
auto pad_num_p = reinterpret_cast<T *>(inputs[1]->addr);
auto *out = reinterpret_cast<T *>(outputs[0]->addr);
for (size_t batch_i = 0; batch_i < batch_size_; batch_i++) {
T pad_num = *pad_num_p;
for (size_t i = start[batch_i]; i < input_size_; ++i) {
out[i] = pad_num;
}
pad_num_p++;
out += input_size_;
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, UniqueWithPad, UniqueWithPadCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -36,28 +36,10 @@ class UniqueWithPadCpuKernelMod : public UniqueCpuKernelMod {
public:
UniqueWithPadCpuKernelMod() = default;
~UniqueWithPadCpuKernelMod() override = default;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kUniqueWithPadInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kUniqueWithPadOutputsNum, kernel_name_);
int ret = UniqueCpuKernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
if (ret != 0) {
return ret;
}
is_need_retrieve_output_shape_ = false;
if (batch_rank_ > 0) {
auto pad_shape = inputs[kPadNumIndex]->GetShapeVector();
auto pad_nums = std::accumulate(pad_shape.begin(), pad_shape.end(), 1, std::multiplies<int64_t>());
if (pad_nums != static_cast<int64_t>(batch_size_)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the elements num of input 'pad' must be equal to input 'x' batch size, "
"but got the elements num of input 'pad': "
<< Vector2Str(pad_shape) << " and input 'x' batch size: " << batch_size_;
}
}
return ret;
}
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
@ -84,21 +66,7 @@ class UniqueWithPadCpuKernelMod : public UniqueCpuKernelMod {
private:
template <typename T>
void PadOutput(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
const std::vector<size_t> &start) {
if (inputs.size() < kUniqueWithPadInputsNum || outputs.size() < kUniqueWithPadOutputsNum) {
return;
}
auto pad_num_p = reinterpret_cast<T *>(inputs[1]->addr);
auto *out = reinterpret_cast<T *>(outputs[0]->addr);
for (size_t batch_i = 0; batch_i < batch_size_; batch_i++) {
T pad_num = *pad_num_p;
for (size_t i = start[batch_i]; i < input_size_; ++i) {
out[i] = pad_num;
}
pad_num_p++;
out += input_size_;
}
}
const std::vector<size_t> &start);
};
} // namespace kernel
} // namespace mindspore

View File

@ -30,6 +30,7 @@ std::unique_ptr<cukernel::GpuKernelHelperBase> CreateUniqueWithPadKernelPtr(cons
const uint32_t &device_id) {
return std::make_unique<cukernel::UniqueWithPadHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using UniqueWithPadPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
@ -62,6 +63,7 @@ const std::vector<std::pair<KernelAttr, UniqueWithPadPtrCreatorFunc>> kernel_att
bool UniqueWithPadGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
base_operator_ = base_operator;
kernel_name_ = base_operator->name();
auto batch_rank = base_operator->get_batch_rank();
@ -69,13 +71,23 @@ bool UniqueWithPadGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
return false;
}
batch_rank_ = static_cast<size_t>(batch_rank);
inputs_ = inputs;
outputs_ = outputs;
auto [is_match, index] = MatchKernelAttr(GetKernelAttrFromTensors(inputs, outputs), GetOpSupport());
if (!is_match) {
return false;
}
helper_ptr_ = kernel_attr[index].second(kernel_name_, device_id_);
return true;
}
int UniqueWithPadGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
inputs_ = inputs;
outputs_ = outputs;
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
constexpr size_t kUniqueWithPadInputNum = 2;
@ -102,10 +114,11 @@ bool UniqueWithPadGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
<< Vector2Str(pad_shape) << " and input 'x' batch size: " << batch_size;
}
}
is_null_input_ = CHECK_SHAPE_NULL(shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
return KRET_OK;
}
input_shapes.emplace_back(inputs[0]->GetDeviceShapeAdaptively());
@ -113,18 +126,7 @@ bool UniqueWithPadGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const
helper_ptr_->CalMemSize(input_shapes, output_shapes);
InitSizeLists();
is_need_retrieve_output_shape_ = false;
if (!is_input_dynamic_shape_.has_value()) {
bool is_input_dynamic_shape = false;
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (std::any_of(input_shape.begin(), input_shape.end(), [](int64_t dim) { return dim < 0; })) {
is_input_dynamic_shape = true;
break;
}
}
is_input_dynamic_shape_ = is_input_dynamic_shape;
}
return true;
return KRET_OK;
}
std::vector<KernelAttr> UniqueWithPadGpuKernelMod::GetOpSupport() {

View File

@ -29,17 +29,17 @@ namespace mindspore {
namespace kernel {
class UniqueWithPadGpuKernelMod : public UniqueGpuKernelMod {
public:
UniqueWithPadGpuKernelMod() {
KernelMod::kernel_name_ = "UniqueWithPad";
ResetResource();
}
UniqueWithPadGpuKernelMod() { ResetResource(); }
~UniqueWithPadGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
protected:
void SyncData() override{};
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel

View File

@ -139,8 +139,6 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUniqueWithPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplOCRRecognitionPreHandle(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -264,60 +264,6 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
return std::make_shared<AbstractTuple>(elements);
}
AbstractBasePtr InferImplUniqueWithPad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// inputs: a 1-d Tensor
const std::string op_name = primitive->name();
constexpr size_t kUniqueWithPadInputNum = 2;
constexpr size_t kPadIndex = 1;
CheckArgsSize(op_name, args_spec_list, kUniqueWithPadInputNum);
AbstractTensorPtr input = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto shape = input->shape();
MS_EXCEPTION_IF_NULL(shape);
size_t batch_rank = 0;
if (primitive->HasAttr(ops::kBatchRank)) {
auto value_ptr = primitive->GetAttr(ops::kBatchRank);
batch_rank = GetValue<int64_t>(value_ptr);
}
if (batch_rank != 0) {
(void)CheckAndConvertUtils::CheckInteger("input_shape size", shape->shape().size(), kEqual, batch_rank + 1,
op_name);
AbstractTensorPtr pad = CheckArg<AbstractTensor>(op_name, args_spec_list, kPadIndex);
auto pad_shape = pad->shape();
MS_EXCEPTION_IF_NULL(pad_shape);
auto pad_num = std::accumulate(pad_shape->shape().begin(), pad_shape->shape().end(), 1, std::multiplies<int64_t>());
auto input_batch =
std::accumulate(shape->shape().begin(), shape->shape().begin() + batch_rank, 1, std::multiplies<int64_t>());
(void)CheckAndConvertUtils::CheckInteger("elements num of input 'pad'", pad_num, kEqual, input_batch, op_name);
} else {
if (shape->shape().size() != 1) {
MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1.";
}
}
// Currently we choose the same data type as input for the idx.
TypePtr ids_idx_type = kInt32;
MS_EXCEPTION_IF_NULL(input->element());
MS_EXCEPTION_IF_NULL(input->element()->GetTypeTrack());
if (input->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
ids_idx_type = kInt64;
}
ShapeVector idx_shape = shape->shape();
ShapeVector idx_min_shape = shape->min_shape();
if (idx_min_shape.empty()) {
idx_min_shape = shape->shape();
}
ShapeVector idx_max_shape = shape->max_shape();
if (idx_max_shape.empty()) {
idx_max_shape = shape->shape();
}
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
AbstractBasePtr ids = input->Broaden();
return std::make_shared<AbstractTuple>(AbstractBasePtrList({ids, ids_idx}));
}
AbstractBasePtr InferImplPadAndShift(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// inputs: a 1-d Tensor

View File

@ -280,7 +280,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimArrayToScalar, R{InferImplArrayToScalar, nullptr, true}},
{prim::kPrimBroadcastShape, R{InferImplBroadCastShape, nullptr, true}},
{prim::kPrimUnique, R{InferImplUnique, nullptr, true}},
{prim::kPrimUniqueWithPad, R{InferImplUniqueWithPad, nullptr, true}},
{prim::kPrimUniqueGrad, R{InferImplUniqueGrad, nullptr, true}},
{prim::kPrimEmbeddingLookup, R{InferImplEmbeddingLookup, nullptr, true}},
{prim::kPrimSparseGatherV2, R{InferImplGatherV2, nullptr, true}},

View File

@ -15,12 +15,90 @@
*/
#include "ops/unique_with_pad.h"
#include "ops/primitive_c.h"
#include <functional>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/op_utils.h"
#include "ops/op_name.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kUniqueWithPadInputsNum = 2;
constexpr size_t kUniqueWithPadOutputsNum = 2;
abstract::TupleShapePtr UniqueWithPadInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto pad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto is_dynamic = IsDynamic(x_shape) || IsDynamic(pad_shape);
size_t batch_rank = 0;
if (primitive->HasAttr(ops::kBatchRank)) {
auto value_ptr = primitive->GetAttr(ops::kBatchRank);
batch_rank = GetValue<int64_t>(value_ptr);
}
if (!IsDynamicRank(x_shape)) {
(void)CheckAndConvertUtils::CheckInteger("input_shape_size", x_shape.size(), kEqual, batch_rank + 1, prim_name);
}
constexpr int64_t kNumZero = 0;
if (!is_dynamic && batch_rank != kNumZero) {
auto pad_num = std::accumulate(pad_shape.begin(), pad_shape.end(), 1, std::multiplies<int64_t>());
auto input_batch = std::accumulate(x_shape.begin(), x_shape.begin() + batch_rank, 1, std::multiplies<int64_t>());
(void)CheckAndConvertUtils::CheckInteger("elements num of input 'pad'", pad_num, kEqual, input_batch, prim_name);
}
auto x_shape_ptr = std::make_shared<abstract::Shape>(x_shape);
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{x_shape_ptr, x_shape_ptr});
}
TuplePtr UniqueWithPadInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto x_type = input_args[0]->BuildType();
std::set<TypePtr> valid_types = {kInt32, kInt64, kFloat32};
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
TypePtr y_type = x_type;
TypePtr idx_type = kInt32;
abstract::AbstractTensorPtr x_ptr = input_args.at(kInputIndex0)->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(x_ptr->element());
MS_EXCEPTION_IF_NULL(x_ptr->element()->GetTypeTrack());
if (x_ptr->element()->GetTypeTrack()->type_id() == TypeId::kNumberTypeInt64) {
idx_type = kInt64;
}
return std::make_shared<Tuple>(std::vector<TypePtr>{y_type, idx_type});
}
} // namespace
AbstractBasePtr UniqueWithPadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (auto &input : input_args) {
MS_EXCEPTION_IF_NULL(input);
}
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kUniqueWithPadInputsNum, prim_name);
auto infer_type = UniqueWithPadInferType(primitive, input_args);
auto infer_shape = UniqueWithPadInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(UniqueWithPad, BaseOperator);
REGISTER_PRIMITIVE_C(kNameUniqueWithPad, UniqueWithPad);
REGISTER_PRIMITIVE_EVAL_IMPL(UniqueWithPad, prim::kPrimUniqueWithPad, UniqueWithPadInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -16,9 +16,14 @@
#ifndef MINDSPORE_CORE_OPS_UNIQUE_WITH_PAD_H_
#define MINDSPORE_CORE_OPS_UNIQUE_WITH_PAD_H_
#include <memory>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/base_operator.h"
#include "utils/check_convert_utils.h"
#include "mindapi/base/types.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ops {
@ -34,6 +39,10 @@ class MIND_API UniqueWithPad : public BaseOperator {
/// \brief Init.
void Init() const {}
};
AbstractBasePtr UniqueWithPadInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimUniqueWithPadPtr = std::shared_ptr<UniqueWithPad>;
} // namespace ops
} // namespace mindspore

View File

@ -1246,7 +1246,7 @@ class Padding(Primitive):
self.pad_dim_size = pad_dim_size
class UniqueWithPad(PrimitiveWithCheck):
class UniqueWithPad(Primitive):
"""
Returns unique elements and relative indexes in 1-D tensor, filled with padding num.
@ -1273,14 +1273,7 @@ class UniqueWithPad(PrimitiveWithCheck):
@prim_attr_register
def __init__(self):
"""init UniqueWithPad"""
def __check__(self, x, pad_num):
type_list = [mstype.int32, mstype.int64, mstype.float32]
validator.check_tensor_dtype_valid("x", x['dtype'], type_list, self.name)
if not hasattr(self, 'batch_rank'):
validator.check_subclass("pad_num", pad_num['dtype'], type_list, self.name)
x_shape = list(x['shape'])
validator.check("rank of x", len(x_shape), '', 1, Rel.EQ, self.name)
self.init_prim_io_names(inputs=['x', 'pad_num'], outputs=['y', 'idx'])
class Split(Primitive):

View File

@ -26,6 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class Net(nn.Cell):
def __init__(self, pad_num):
super(Net, self).__init__()
self.uniq = P.UniqueWithPad()
@ -35,6 +36,35 @@ class Net(nn.Cell):
return self.uniq(x, self.pad_num)
def dyn_case():
net = Net(0)
x_dyn = Tensor(shape=[None], dtype=mstype.int32)
net.set_inputs(x_dyn)
x = Tensor(np.array([1, 1, 2, 2, 3, 3, 4, 5]), dtype=mstype.int32)
out = net(x)
expect_shape = (8,)
for i in range(2):
assert out[i].asnumpy().shape == expect_shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net_dyn():
"""
Feature: test uniquewithpad in cpu.
Description: test the ops in dynamic shape.
Expectation: expect correct shape result.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
dyn_case()
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
dyn_case()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -107,12 +137,17 @@ def test_unique_with_pad_vmap():
def cal_unique_with_pad(x):
return P.UniqueWithPad()(x, -1)
x = Tensor(np.array([[[1, 2, 5, 2], [1, 2, 5, 2]], [[1, 2, 5, 2], [1, 2, 5, 2]]]).astype(np.int32))
x = Tensor(
np.array([[[1, 2, 5, 2], [1, 2, 5, 2]],
[[1, 2, 5, 2], [1, 2, 5, 2]]]).astype(np.int32))
vmap_unique_with_pad = vmap(vmap(cal_unique_with_pad, in_axes=0), in_axes=0)
vmap_unique_with_pad = vmap(vmap(cal_unique_with_pad, in_axes=0),
in_axes=0)
outputs = vmap_unique_with_pad(x)
expect0 = np.array([[[1, 2, 5, -1], [1, 2, 5, -1]], [[1, 2, 5, -1], [1, 2, 5, -1]]]).astype(np.int32)
expect1 = np.array([[[0, 1, 2, 1], [0, 1, 2, 1]], [[0, 1, 2, 1], [0, 1, 2, 1]]]).astype(np.int32)
expect0 = np.array([[[1, 2, 5, -1], [1, 2, 5, -1]],
[[1, 2, 5, -1], [1, 2, 5, -1]]]).astype(np.int32)
expect1 = np.array([[[0, 1, 2, 1], [0, 1, 2, 1]],
[[0, 1, 2, 1], [0, 1, 2, 1]]]).astype(np.int32)
assert np.allclose(outputs[0].asnumpy(), expect0)
assert np.allclose(outputs[1].asnumpy(), expect1)