Add support of scalars with AggregateFunction() type. [#CLICKHOUSE-2845]

Fixed segfaults for arrayReduce. [#CLICKHOUSE-2787]
This commit is contained in:
Vitaliy Lyudvichenko 2017-02-15 14:23:38 +03:00 committed by alexey-milovidov
parent c52670776c
commit 97c4211409
15 changed files with 357 additions and 29 deletions

View File

@ -120,6 +120,11 @@ public:
nested_func->insertResultInto(place, to);
}
/// Forwards the query to the nested aggregate function; this wrapper contributes no answer of its own.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{
static_cast<const AggregateFunctionArray &>(*that).add(place, columns, row_num, arena);

View File

@ -103,6 +103,11 @@ public:
nested_func->insertResultInto(place, to);
}
/// Forwards the query to the nested aggregate function; this wrapper contributes no answer of its own.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{
static_cast<const AggregateFunctionIf &>(*that).add(place, columns, row_num, arena);

View File

@ -104,6 +104,11 @@ public:
nested_func->insertResultInto(place, to);
}
/// Forwards the query to the nested aggregate function; this wrapper contributes no answer of its own.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)
{
static_cast<const AggregateFunctionMerge &>(*that).add(place, columns, row_num, arena);

View File

@ -255,6 +255,11 @@ public:
nested_function->add(nestedPlace(place), nested_columns, row_num, arena);
}
/// Forwards the query to the nested aggregate function; this wrapper contributes no answer of its own.
bool allocatesMemoryInArena() const override { return nested_function->allocatesMemoryInArena(); }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
const IColumn ** columns, size_t row_num, Arena * arena)
{

View File

@ -100,6 +100,11 @@ public:
/// Tells an aggregate function apart from the state of an aggregate function; this class always reports a state.
bool isState() const override { return true; }
/// Forwards the query to the nested aggregate function; this wrapper contributes no answer of its own.
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
/// Returns a copy of the nested_func_owner handle, giving callers access to the nested aggregate function.
AggregateFunctionPtr getNestedFunction() const { return nested_func_owner; }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena)

View File

@ -160,11 +160,21 @@ public:
insertMergeFrom(src, n);
}
/// Appends one row via insertDefault() and then merges the state at `place` into that (now last) row.
void insertFrom(ConstAggregateDataPtr place)
{
insertDefault();
insertMergeFrom(place);
}
/// Merges the state pointed to by `place` into the state held in the last row of this column.
/// The column's arena (created on first use by createOrGetArena()) is passed to the merge.
void insertMergeFrom(ConstAggregateDataPtr place)
{
    auto & arena = createOrGetArena();
    func->merge(getData().back(), place, &arena);
}
void insertMergeFrom(const IColumn & src, size_t n)
{
Arena & arena = createOrGetArena();
func->merge(getData().back(), static_cast<const ColumnAggregateFunction &>(src).getData()[n], &arena);
insertMergeFrom(static_cast<const ColumnAggregateFunction &>(src).getData()[n]);
}
Arena & createOrGetArena()

View File

@ -0,0 +1,202 @@
#pragma once
#include <DB/Columns/ColumnConst.h>
#include <DB/DataTypes/DataTypeAggregateFunction.h>
namespace DB
{
/** Constant column for values of an AggregateFunction data type.
  * Every row logically holds the same `value`; only that single Field, the data type
  * and the row count `s` are stored. getDataAt() reads the value as a String.
  */
class ColumnConstAggregateFunction : public IColumnConst
{
public:
    /// size - number of rows; value_ - the constant value shared by all rows;
    /// data_type_ - must be a DataTypeAggregateFunction (see getAggregateFunction()).
    ColumnConstAggregateFunction(size_t size, const Field & value_, const DataTypePtr & data_type_)
        : data_type(data_type_), value(value_), s(size)
    {
    }

    String getName() const override
    {
        return "ColumnConst<ColumnAggregateFunction>";
    }

    bool isConst() const override
    {
        return true;
    }

    /// Materializes the constant: creates a ColumnAggregateFunction for this type's function
    /// and inserts `value` into it once per row.
    ColumnPtr convertToFullColumnIfConst() const override
    {
        auto res = std::make_shared<ColumnAggregateFunction>(getAggregateFunction());

        for (size_t i = 0; i < s; ++i)
            res->insert(value);

        return res;
    }

    ColumnPtr convertToFullColumn() const override
    {
        return convertToFullColumnIfConst();
    }

    /// Same value and type, different number of rows.
    ColumnPtr cloneResized(size_t new_size) const override
    {
        return std::make_shared<ColumnConstAggregateFunction>(new_size, value, data_type);
    }

    size_t size() const override
    {
        return s;
    }

    Field operator[](size_t n) const override
    {
        /// NOTE: there are no out of bounds check (like in ColumnConstBase)
        return value;
    }

    void get(size_t n, Field & res) const override
    {
        /// NOTE: there are no out of bounds check (like in ColumnConstBase)
        res = value;
    }

    /// The row index is ignored: every row refers to the same String stored inside `value`.
    StringRef getDataAt(size_t n) const override
    {
        return value.get<const String &>();
    }

    /// Grows the column by one row if `x` equals the constant; otherwise throws
    /// CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN.
    void insert(const Field & x) override
    {
        /// NOTE: Cannot check source function of x
        if (value != x)
            throw Exception("Cannot insert different element into constant column " + getName(),
                ErrorCodes::CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN);
        ++s;
    }

    /// Grows the column by `length` rows if `src` is a constant aggregate-function column with the
    /// same function name and value; otherwise throws CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN.
    /// `start` is irrelevant because all rows of a constant column are equal.
    void insertRangeFrom(const IColumn & src, size_t start, size_t length) override
    {
        if (!equalsFuncAndValue(src))
            throw Exception("Cannot insert different element into constant column " + getName(),
                ErrorCodes::CANNOT_INSERT_ELEMENT_INTO_CONSTANT_COLUMN);

        s += length;
    }

    void insertData(const char * pos, size_t length) override
    {
        throw Exception("Method insertData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    /// Only the row count changes; the new row holds the same constant as all others.
    void insertDefault() override
    {
        ++s;
    }

    /// NOTE: no check that n <= size(); the caller is responsible for that.
    void popBack(size_t n) override
    {
        s -= n;
    }

    StringRef serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const override
    {
        throw Exception("Method serializeValueIntoArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    const char * deserializeAndInsertFromArena(const char * pos) override
    {
        throw Exception("Method deserializeAndInsertFromArena is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    void updateHashWithValue(size_t n, SipHash & hash) const override
    {
        throw Exception("Method updateHashWithValue is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    /// The result is again constant; its size is countBytesInFilter(filt).
    /// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the filter length differs from the column size.
    ColumnPtr filter(const Filter & filt, ssize_t result_size_hint) const override
    {
        if (s != filt.size())
            throw Exception("Size of filter doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

        return std::make_shared<ColumnConstAggregateFunction>(countBytesInFilter(filt), value, data_type);
    }

    /// Permuting equal rows changes nothing but the size: limit == 0 means "all rows",
    /// otherwise the result has min(size, limit) rows.
    /// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the permutation is shorter than that.
    ColumnPtr permute(const Permutation & perm, size_t limit) const override
    {
        if (limit == 0)
            limit = s;
        else
            limit = std::min(s, limit);

        if (perm.size() < limit)
            throw Exception("Size of permutation is less than required.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

        return std::make_shared<ColumnConstAggregateFunction>(limit, value, data_type);
    }

    int compareAt(size_t n, size_t m, const IColumn & rhs_, int nan_direction_hint) const override
    {
        throw Exception("Method compareAt is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
    }

    /// Produces the identity permutation 0..size-1; `reverse` and `limit` are ignored.
    void getPermutation(bool reverse, size_t limit, Permutation & res) const override
    {
        res.resize(s);
        for (size_t i = 0; i < s; ++i)
            res[i] = i;
    }

    /// The result is constant with offsets.back() rows (0 for an empty column).
    /// Throws SIZES_OF_COLUMNS_DOESNT_MATCH if the number of offsets differs from the column size.
    ColumnPtr replicate(const Offsets_t & offsets) const override
    {
        if (s != offsets.size())
            throw Exception("Size of offsets doesn't match size of column.", ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH);

        size_t replicated_size = 0 == s ? 0 : offsets.back();
        return std::make_shared<ColumnConstAggregateFunction>(replicated_size, value, data_type);
    }

    /// Both extremes are the constant itself.
    void getExtremes(Field & min, Field & max) const override
    {
        min = value;
        max = value;
    }

    /// Counts only the Field object and the row counter, not anything the Field may own.
    size_t byteSize() const override
    {
        return sizeof(value) + sizeof(s);
    }

    size_t allocatedSize() const override
    {
        return byteSize();
    }

private:
    DataTypePtr data_type;
    Field value;
    size_t s;

    /// Function of the underlying data type; typeid_cast to a reference requires
    /// data_type to really be a DataTypeAggregateFunction.
    AggregateFunctionPtr getAggregateFunction() const
    {
        return typeid_cast<const DataTypeAggregateFunction &>(*data_type).getFunction();
    }

    /// True only when rhs is itself a ColumnConstAggregateFunction that passes the check below.
    bool equalsFuncAndValue(const IColumn & rhs) const
    {
        /// dynamic_cast yields nullptr for any other column type: such a column never matches
        /// and must not be dereferenced.
        const auto * rhs_const = dynamic_cast<const ColumnConstAggregateFunction *>(&rhs);
        return rhs_const && equalsFuncAndValue(*rhs_const);
    }

    bool equalsFuncAndValue(const ColumnConstAggregateFunction & rhs) const
    {
        /// Shallow check, no args, no params
        return value == rhs.value
            && getAggregateFunction()->getName() == rhs.getAggregateFunction()->getName();
    }
};
}

View File

@ -462,4 +462,13 @@ void swap(PODArray<T, INITIAL_SIZE, TAllocator, pad_right_> & lhs, PODArray<T, I
/// PODArray with its fourth template argument fixed to 15; element type, initial size and allocator stay configurable.
template <typename T, size_t INITIAL_SIZE = 4096, typename TAllocator = Allocator<false>>
using PaddedPODArray = PODArray<T, INITIAL_SIZE, TAllocator, 15>;
/** Rounds `value` up to the nearest multiple of `dividend` (which, despite its name, acts as the divisor).
  * A value that is already a multiple is returned unchanged. `dividend` must be non-zero.
  *
  * The remainder form is used instead of `value + dividend - 1`, whose intermediate sum can wrap
  * around even when the rounded result fits in size_t.
  * Kept as a single return statement so that it remains usable in constant expressions
  * such as template arguments.
  */
constexpr size_t integerRound(size_t value, size_t dividend)
{
    return value % dividend == 0 ? value : value + (dividend - value % dividend);
}
/// PODArray with INITIAL_SIZE 0 whose allocator is AllocatorWithStackMemory over Allocator<false>.
/// The allocator's size argument is stack_size_in_bytes rounded up to a multiple of sizeof(T).
template <typename T, size_t stack_size_in_bytes>
using PODArrayWithStackMemory = PODArray<T, 0, AllocatorWithStackMemory<Allocator<false>, integerRound(stack_size_in_bytes, sizeof(T))>>;
}

View File

@ -32,6 +32,7 @@ public:
}
/// Name reported by the aggregate function held in `function`.
std::string getFunctionName() const { return function->getName(); }
/// Returns a copy of the `function` handle, i.e. the aggregate function this data type is built around.
AggregateFunctionPtr getFunction() const { return function; }
std::string getName() const override;
@ -62,10 +63,7 @@ public:
ColumnPtr createColumn() const override;
ColumnPtr createConstColumn(size_t size, const Field & field) const override;
Field getDefault() const override
{
throw Exception("There is no default value for AggregateFunction data type", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
}
Field getDefault() const override;
};

View File

@ -141,6 +141,12 @@ public:
throw Exception("getSizeOfField() method is not implemented for data type " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/// Two data types are treated as the same type exactly when their getName() strings are equal.
bool equals(const IDataType & rhs) const
{
    return rhs.getName() == getName();
}
/// Virtual, so that deleting a derived data type through an IDataType pointer runs the derived destructor.
virtual ~IDataType() {}
};

View File

@ -1498,19 +1498,11 @@ private:
};
}
/// Only trivial NULL -> NULL case
WrapperType createNullWrapper(const DataTypePtr & from_type, const DataTypeNull * to_type)
WrapperType createIdentityWrapper(const DataTypePtr &)
{
if (!typeid_cast<const DataTypeNull *>(from_type.get()))
throw Exception("Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE);
return [] (Block & block, const ColumnNumbers & arguments, const size_t result)
{
// just copy pointer to Null column
ColumnWithTypeAndName & res_col = block.safeGetByPosition(result);
const ColumnWithTypeAndName & src_col = block.safeGetByPosition(arguments.front());
res_col.column = src_col.column;
block.safeGetByPosition(result).column = block.safeGetByPosition(arguments.front()).column;
};
}
@ -1602,7 +1594,9 @@ private:
WrapperType prepareImpl(const DataTypePtr & from_type, const IDataType * const to_type)
{
if (const auto to_actual_type = typeid_cast<const DataTypeUInt8 *>(to_type))
if (from_type->equals(*to_type))
return createIdentityWrapper(from_type);
else if (const auto to_actual_type = typeid_cast<const DataTypeUInt8 *>(to_type))
return createWrapper(from_type, to_actual_type);
else if (const auto to_actual_type = typeid_cast<const DataTypeUInt16 *>(to_type))
return createWrapper(from_type, to_actual_type);
@ -1638,8 +1632,6 @@ private:
return createEnumWrapper(from_type, type_enum);
else if (const auto type_enum = typeid_cast<const DataTypeEnum16 *>(to_type))
return createEnumWrapper(from_type, type_enum);
else if (const auto type_null = typeid_cast<const DataTypeNull *>(to_type))
return createNullWrapper(from_type, type_null);
/// It's possible to use ConvertImplGenericFromString to convert from String to AggregateFunction,
/// but it is disabled because deserializing aggregate functions state might be unsafe.
@ -1691,7 +1683,7 @@ private:
else if (const auto type = typeid_cast<const DataTypeEnum16 *>(to_type))
monotonicity_for_range = monotonicityForType(type);
}
/// other types like FixedString, Array and Tuple have no monotonicity defined
/// other types like Null, FixedString, Array and Tuple have no monotonicity defined
}
public:

View File

@ -3,8 +3,8 @@
#include <DB/IO/WriteHelpers.h>
#include <DB/IO/ReadHelpers.h>
#include <DB/Columns/ColumnConst.h>
#include <DB/Columns/ColumnAggregateFunction.h>
#include <DB/Columns/ColumnConstAggregateFunction.h>
#include <DB/DataTypes/DataTypeAggregateFunction.h>
@ -29,8 +29,8 @@ std::string DataTypeAggregateFunction::getName() const
stream << ")";
}
for (DataTypes::const_iterator it = argument_types.begin(); it != argument_types.end(); ++it)
stream << ", " << (*it)->getName();
for (const auto & argument_type: argument_types)
stream << ", " << argument_type->getName();
stream << ")";
return stream.str();
@ -236,7 +236,33 @@ ColumnPtr DataTypeAggregateFunction::createColumn() const
ColumnPtr DataTypeAggregateFunction::createConstColumn(size_t size, const Field & field) const
{
throw Exception("Const column with aggregate function is not supported", ErrorCodes::NOT_IMPLEMENTED);
return std::make_shared<ColumnConstAggregateFunction>(size, field, clone());
}
/** Default value of the type: the serialized form of a freshly created state.
  * A temporary state is created in a local buffer of function->sizeOfData() bytes, serialized
  * into the String held by the returned Field, and destroyed again on both the normal and
  * the exceptional path.
  */
Field DataTypeAggregateFunction::getDefault() const
{
    Field serialized_state = String();

    PODArrayWithStackMemory<char, 16> state_storage(function->sizeOfData());
    AggregateDataPtr state = state_storage.data();

    function->create(state);

    try
    {
        /// The write buffer lives only inside this scope, so it is gone before the state is destroyed.
        WriteBufferFromString out(serialized_state.get<String &>());
        function->serialize(state, out);
    }
    catch (...)
    {
        /// Serialization failed: release the temporary state, then let the error propagate.
        function->destroy(state);
        throw;
    }

    function->destroy(state);
    return serialized_state;
}

View File

@ -2723,10 +2723,6 @@ void FunctionArrayReduce::getReturnTypeAndPrerequisitesImpl(
aggregate_function = AggregateFunctionFactory().get(aggregate_function_name, argument_types);
/// Потому что владение состояниями агрегатных функций никуда не отдаётся.
if (aggregate_function->isState())
throw Exception("Using aggregate function with -State modifier in function arrayReduce is not supported", ErrorCodes::BAD_ARGUMENTS);
if (has_parameters)
aggregate_function->setParameters(params_row);
aggregate_function->setArguments(argument_types);
@ -2778,6 +2774,9 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
block.safeGetByPosition(result).column = result_holder;
IColumn & res_col = *result_holder.get();
/// AggregateFunction's states should be inserted into column using specific way
auto res_col_aggregate_function = typeid_cast<ColumnAggregateFunction *>(&res_col);
ColumnArray::Offset_t current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
@ -2789,7 +2788,10 @@ void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & argum
for (size_t j = current_offset; j < next_offset; ++j)
agg_func.add(place, aggregate_arguments, j, arena.get());
agg_func.insertResultInto(place, res_col);
if (!res_col_aggregate_function)
agg_func.insertResultInto(place, res_col);
else
res_col_aggregate_function->insertFrom(place);
}
catch (...)
{

View File

@ -0,0 +1,25 @@
0 200
1 100
0 200 nan
1 100 nan
0 200 nan
1 100 nan
2 200 101
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
3 200 102 ['igua']
0 200 nan ['---']
1 100 nan ['---']
2 200 101 ['---']
3 200 102 ['igua']
---

View File

@ -0,0 +1,33 @@
-- Start from a clean slate.
DROP TABLE IF EXISTS test.agg_func_col;

-- Column d holds an AggregateFunction(sum, UInt64) state; its DEFAULT is itself a state built with arrayReduce('sumState', ...).
CREATE TABLE test.agg_func_col (p Date, k UInt8, d AggregateFunction(sum, UInt64) DEFAULT arrayReduce('sumState', [toUInt64(200)])) ENGINE = AggregatingMergeTree(p, k, 1);
-- k = 0 omits d, so the DEFAULT expression supplies it.
INSERT INTO test.agg_func_col (k) VALUES (0);
-- k = 1 provides an explicit state produced by arrayReduce.
INSERT INTO test.agg_func_col SELECT 1 AS k, arrayReduce('sumState', [toUInt64(100)]) AS d;
SELECT k, sumMerge(d) FROM test.agg_func_col GROUP BY k ORDER BY k;

-- Empty line separating the result sets in the output.
SELECT '';
-- Add an AggregateFunction column without a DEFAULT expression; rows inserted before this have no stored value for it.
ALTER TABLE test.agg_func_col ADD COLUMN af_avg1 AggregateFunction(avg, UInt8);
SELECT k, sumMerge(d), avgMerge(af_avg1) FROM test.agg_func_col GROUP BY k ORDER BY k;

SELECT '';
-- The VALUES form below is kept commented out; the INSERT ... SELECT form is used instead.
--INSERT INTO test.agg_func_col (k, af_avg1) VALUES (2, arrayReduce('avgState', [101]));
-- k = 2 sets af_avg1 explicitly and leaves d to its DEFAULT.
INSERT INTO test.agg_func_col SELECT 2 AS k, arrayReduce('avgState', [101]) AS af_avg1;
SELECT k, sumMerge(d), avgMerge(af_avg1) FROM test.agg_func_col GROUP BY k ORDER BY k;

SELECT '';
-- Add a second AggregateFunction column, this time with a DEFAULT state over String arguments.
ALTER TABLE test.agg_func_col ADD COLUMN af_gua AggregateFunction(groupUniqArray, String) DEFAULT arrayReduce('groupUniqArrayState', ['---', '---']);
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;

SELECT '';
-- k = 3 sets both added columns explicitly.
INSERT INTO test.agg_func_col SELECT 3 AS k, arrayReduce('avgState', [102, 102]) AS af_avg1, arrayReduce('groupUniqArrayState', ['igua', 'igua']) AS af_gua;
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;

-- Run OPTIMIZE and repeat the same query afterwards.
OPTIMIZE TABLE test.agg_func_col;

SELECT '';
SELECT k, sumMerge(d), avgMerge(af_avg1), groupUniqArrayMerge(af_gua) FROM test.agg_func_col GROUP BY k ORDER BY k;

DROP TABLE IF EXISTS test.agg_func_col;

SELECT '';
-- arrayReduce with an -If combinator over Nullable(String) arguments
-- (presumably the regression case for the arrayReduce crash mentioned in the commit message).
SELECT arrayReduce('groupUniqArrayIf', [CAST('---' AS Nullable(String)), CAST('---' AS Nullable(String))], [1, 1])[1];