!1194 DB4AI feature merge

Merge pull request !1194 from nwen/master
This commit is contained in:
opengauss-bot 2021-08-05 12:15:05 +00:00 committed by Gitee
commit b2cc4b3ee9
110 changed files with 15344 additions and 25 deletions

View File

@ -0,0 +1,128 @@
<!--
doc/src/sgml/ref/alter_operator.sgml
PostgreSQL documentation
-->
<refentry id="SQL-ALTEROPERATOR">
<refmeta>
<refentrytitle>ALTER OPERATOR</refentrytitle>
<manvolnum>7</manvolnum>
<refmiscinfo>SQL - Language Statements</refmiscinfo>
</refmeta>
<refnamediv>
<refname>ALTER OPERATOR</refname>
<refpurpose>change the definition of an operator</refpurpose>
</refnamediv>
<indexterm zone="sql-alteroperator">
<primary>ALTER OPERATOR</primary>
</indexterm>
<refsynopsisdiv>
<synopsis>
ALTER OPERATOR <replaceable>name</replaceable> ( { <replaceable>left_type</replaceable> | NONE } , { <replaceable>right_type</replaceable> | NONE } ) OWNER TO <replaceable>new_owner</replaceable>
ALTER OPERATOR <replaceable>name</replaceable> ( { <replaceable>left_type</replaceable> | NONE } , { <replaceable>right_type</replaceable> | NONE } ) SET SCHEMA <replaceable>new_schema</replaceable>
</synopsis>
</refsynopsisdiv>
<refsect1>
<title>Description</title>
<para>
<command>ALTER OPERATOR</command> changes the definition of
an operator. The only currently available functionality is to change the
owner or the schema of the operator.
</para>
<para>
You must own the operator to use <command>ALTER OPERATOR</>.
To alter the owner, you must also be a direct or indirect member of the new
owning role, and that role must have <literal>CREATE</literal> privilege on
the operator's schema. (These restrictions enforce that altering the owner
doesn't do anything you couldn't do by dropping and recreating the operator.
However, a superuser can alter ownership of any operator anyway.)
</para>
</refsect1>
<refsect1>
<title>Parameters</title>
<variablelist>
<varlistentry>
<term><replaceable class="parameter">name</replaceable></term>
<listitem>
<para>
The name (optionally schema-qualified) of an existing operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">left_type</replaceable></term>
<listitem>
<para>
The data type of the operator's left operand; write
<literal>NONE</literal> if the operator has no left operand.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">right_type</replaceable></term>
<listitem>
<para>
The data type of the operator's right operand; write
<literal>NONE</literal> if the operator has no right operand.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">new_owner</replaceable></term>
<listitem>
<para>
The new owner of the operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">new_schema</replaceable></term>
<listitem>
<para>
The new schema for the operator.
</para>
</listitem>
</varlistentry>
</variablelist>
</refsect1>
<refsect1>
<title>Examples</title>
<para>
Change the owner of a custom operator <literal>a @@ b</literal> for type <type>text</type>:
<programlisting>
ALTER OPERATOR @@ (text, text) OWNER TO joe;
</programlisting></para>
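<para>
Move the same operator into a different schema, matching the
<literal>SET SCHEMA</literal> form shown in the synopsis
(<literal>myschema</literal> is a hypothetical schema name):
<programlisting>
ALTER OPERATOR @@ (text, text) SET SCHEMA myschema;
</programlisting></para>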
</refsect1>
<refsect1>
<title>Compatibility</title>
<para>
There is no <command>ALTER OPERATOR</command> statement in
the SQL standard.
</para>
</refsect1>
<refsect1>
<title>See Also</title>
<simplelist type="inline">
<member><xref linkend="sql-createoperator"></member>
<member><xref linkend="sql-dropoperator"></member>
</simplelist>
</refsect1>
</refentry>

View File

@ -0,0 +1,296 @@
<!--
doc/src/sgml/ref/create_operator.sgml
PostgreSQL documentation
-->
<refentry id="SQL-CREATEOPERATOR">
<refmeta>
<refentrytitle>CREATE OPERATOR</refentrytitle>
<manvolnum>7</manvolnum>
<refmiscinfo>SQL - Language Statements</refmiscinfo>
</refmeta>
<refnamediv>
<refname>CREATE OPERATOR</refname>
<refpurpose>define a new operator</refpurpose>
</refnamediv>
<indexterm zone="sql-createoperator">
<primary>CREATE OPERATOR</primary>
</indexterm>
<refsynopsisdiv>
<synopsis>
CREATE OPERATOR <replaceable>name</replaceable> (
PROCEDURE = <replaceable class="parameter">function_name</replaceable>
[, LEFTARG = <replaceable class="parameter">left_type</replaceable> ] [, RIGHTARG = <replaceable class="parameter">right_type</replaceable> ]
[, COMMUTATOR = <replaceable class="parameter">com_op</replaceable> ] [, NEGATOR = <replaceable class="parameter">neg_op</replaceable> ]
[, RESTRICT = <replaceable class="parameter">res_proc</replaceable> ] [, JOIN = <replaceable class="parameter">join_proc</replaceable> ]
[, HASHES ] [, MERGES ]
)
</synopsis>
</refsynopsisdiv>
<refsect1>
<title>Description</title>
<para>
<command>CREATE OPERATOR</command> defines a new operator,
<replaceable class="parameter">name</replaceable>. The user who
defines an operator becomes its owner. If a schema name is given
then the operator is created in the specified schema. Otherwise it
is created in the current schema.
</para>
<para>
The operator name is a sequence of up to <symbol>NAMEDATALEN</>-1
(63 by default) characters from the following list:
<literallayout>
+ - * / &lt; &gt; = ~ ! @ # % ^ &amp; | ` ?
</literallayout>
There are a few restrictions on your choice of name:
<itemizedlist>
<listitem>
<para><literal>--</literal> and <literal>/*</literal> cannot appear anywhere in an operator name,
since they will be taken as the start of a comment.
</para>
</listitem>
<listitem>
<para>
A multicharacter operator name cannot end in <literal>+</literal> or
<literal>-</literal>,
unless the name also contains at least one of these characters:
<literallayout>
~ ! @ # % ^ &amp; | ` ?
</literallayout>
For example, <literal>@-</literal> is an allowed operator name,
but <literal>*-</literal> is not.
This restriction allows <productname>PostgreSQL</productname> to
parse SQL-compliant commands without requiring spaces between tokens.
</para>
</listitem>
<listitem>
<para>
The use of <literal>=&gt;</> as an operator name is deprecated. It may
be disallowed altogether in a future release.
</para>
</listitem>
</itemizedlist>
</para>
<para>
The operator <literal>!=</literal> is mapped to
<literal>&lt;&gt;</literal> on input, so these two names are always
equivalent.
</para>
<para>
At least one of <literal>LEFTARG</> and <literal>RIGHTARG</> must be defined. For
binary operators, both must be defined. For right unary
operators, only <literal>LEFTARG</> should be defined, while for left
unary operators only <literal>RIGHTARG</> should be defined.
</para>
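<para>
For example, a left unary (prefix) operator defines only
<literal>RIGHTARG</literal>; the underlying function is assumed to take a
single argument of the operand type:
<programlisting>
CREATE OPERATOR !! (
    RIGHTARG = bigint,
    PROCEDURE = numeric_fac
);
</programlisting></para>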
<para>
The <replaceable class="parameter">function_name</replaceable>
procedure must have been previously defined using <command>CREATE
FUNCTION</command> and must be defined to accept the correct number
of arguments (either one or two) of the indicated types.
</para>
<para>
The other clauses specify optional operator optimization clauses.
Their meaning is detailed in <xref linkend="xoper-optimization">.
</para>
<para>
To be able to create an operator, you must have <literal>USAGE</literal>
privilege on the argument types and the return type, as well
as <literal>EXECUTE</literal> privilege on the underlying function. If a
commutator or negator operator is specified, you must own these operators.
</para>
</refsect1>
<refsect1>
<title>Parameters</title>
<variablelist>
<varlistentry>
<term><replaceable class="parameter">name</replaceable></term>
<listitem>
<para>
The name of the operator to be defined. See above for allowable
characters. The name can be schema-qualified, for example
<literal>CREATE OPERATOR myschema.+ (...)</>. If not, then
the operator is created in the current schema. Two operators
in the same schema can have the same name if they operate on
different data types. This is called
<firstterm>overloading</>.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">function_name</replaceable></term>
<listitem>
<para>
The function used to implement this operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">left_type</replaceable></term>
<listitem>
<para>
The data type of the operator's left operand, if any.
This option would be omitted for a left-unary operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">right_type</replaceable></term>
<listitem>
<para>
The data type of the operator's right operand, if any.
This option would be omitted for a right-unary operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">com_op</replaceable></term>
<listitem>
<para>
The commutator of this operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">neg_op</replaceable></term>
<listitem>
<para>
The negator of this operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">res_proc</replaceable></term>
<listitem>
<para>
The restriction selectivity estimator function for this operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">join_proc</replaceable></term>
<listitem>
<para>
The join selectivity estimator function for this operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><literal>HASHES</literal></term>
<listitem>
<para>
Indicates this operator can support a hash join.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><literal>MERGES</literal></term>
<listitem>
<para>
Indicates this operator can support a merge join.
</para>
</listitem>
</varlistentry>
</variablelist>
<para>
To give a schema-qualified operator name in <replaceable
class="parameter">com_op</replaceable> or the other optional
arguments, use the <literal>OPERATOR()</> syntax, for example:
<programlisting>
COMMUTATOR = OPERATOR(myschema.===) ,
</programlisting></para>
</refsect1>
<refsect1>
<title>Notes</title>
<para>
Refer to <xref linkend="xoper"> for further information.
</para>
<para>
It is not possible to specify an operator's lexical precedence in
<command>CREATE OPERATOR</>, because the parser's precedence behavior
is hard-wired. See <xref linkend="sql-precedence"> for precedence details.
</para>
<para>
The obsolete options <literal>SORT1</>, <literal>SORT2</>,
<literal>LTCMP</>, and <literal>GTCMP</> were formerly used to
specify the names of sort operators associated with a merge-joinable
operator. This is no longer necessary, since information about
associated operators is found by looking at B-tree operator families
instead. If one of these options is given, it is ignored except
for implicitly setting <literal>MERGES</> true.
</para>
<para>
Use <xref linkend="sql-dropoperator"> to delete user-defined operators
from a database. Use <xref linkend="sql-alteroperator"> to modify operators in a
database.
</para>
</refsect1>
<refsect1>
<title>Examples</title>
<para>
The following command defines a new operator, area-equality, for
the data type <type>box</type>:
<programlisting>
CREATE OPERATOR === (
LEFTARG = box,
RIGHTARG = box,
PROCEDURE = area_equal_procedure,
COMMUTATOR = ===,
NEGATOR = !==,
RESTRICT = area_restriction_procedure,
JOIN = area_join_procedure,
HASHES, MERGES
);
</programlisting></para>
</refsect1>
<refsect1>
<title>Compatibility</title>
<para>
<command>CREATE OPERATOR</command> is a
<productname>PostgreSQL</productname> extension. There are no
provisions for user-defined operators in the SQL standard.
</para>
</refsect1>
<refsect1>
<title>See Also</title>
<simplelist type="inline">
<member><xref linkend="sql-alteroperator"></member>
<member><xref linkend="sql-createopclass"></member>
<member><xref linkend="sql-dropoperator"></member>
</simplelist>
</refsect1>
</refentry>

View File

@ -0,0 +1,146 @@
<!--
doc/src/sgml/ref/drop_operator.sgml
PostgreSQL documentation
-->
<refentry id="SQL-DROPOPERATOR">
<refmeta>
<refentrytitle>DROP OPERATOR</refentrytitle>
<manvolnum>7</manvolnum>
<refmiscinfo>SQL - Language Statements</refmiscinfo>
</refmeta>
<refnamediv>
<refname>DROP OPERATOR</refname>
<refpurpose>remove an operator</refpurpose>
</refnamediv>
<indexterm zone="sql-dropoperator">
<primary>DROP OPERATOR</primary>
</indexterm>
<refsynopsisdiv>
<synopsis>
DROP OPERATOR [ IF EXISTS ] <replaceable class="PARAMETER">name</replaceable> ( { <replaceable class="PARAMETER">left_type</replaceable> | NONE } , { <replaceable class="PARAMETER">right_type</replaceable> | NONE } ) [ CASCADE | RESTRICT ]
</synopsis>
</refsynopsisdiv>
<refsect1>
<title>Description</title>
<para>
<command>DROP OPERATOR</command> drops an existing operator from
the database system. To execute this command you must be the owner
of the operator.
</para>
</refsect1>
<refsect1>
<title>Parameters</title>
<variablelist>
<varlistentry>
<term><literal>IF EXISTS</literal></term>
<listitem>
<para>
Do not throw an error if the operator does not exist. A notice is issued
in this case.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">name</replaceable></term>
<listitem>
<para>
The name (optionally schema-qualified) of an existing operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">left_type</replaceable></term>
<listitem>
<para>
The data type of the operator's left operand; write
<literal>NONE</literal> if the operator has no left operand.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="parameter">right_type</replaceable></term>
<listitem>
<para>
The data type of the operator's right operand; write
<literal>NONE</literal> if the operator has no right operand.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><literal>CASCADE</literal></term>
<listitem>
<para>
Automatically drop objects that depend on the operator.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><literal>RESTRICT</literal></term>
<listitem>
<para>
Refuse to drop the operator if any objects depend on it. This
is the default.
</para>
</listitem>
</varlistentry>
</variablelist>
</refsect1>
<refsect1>
<title>Examples</title>
<para>
Remove the power operator <literal>a^b</literal> for type <type>integer</type>:
<programlisting>
DROP OPERATOR ^ (integer, integer);
</programlisting>
</para>
<para>
Remove the left unary bitwise complement operator
<literal>~b</literal> for type <type>bit</type>:
<programlisting>
DROP OPERATOR ~ (none, bit);
</programlisting>
</para>
<para>
Remove the right unary factorial operator <literal>x!</literal>
for type <type>bigint</type>:
<programlisting>
DROP OPERATOR ! (bigint, none);
</programlisting></para>
</refsect1>
<refsect1>
<title>Compatibility</title>
<para>
There is no <command>DROP OPERATOR</command> statement in the SQL standard.
</para>
</refsect1>
<refsect1>
<title>See Also</title>
<simplelist type="inline">
<member><xref linkend="sql-createoperator"></member>
<member><xref linkend="sql-alteroperator"></member>
</simplelist>
</refsect1>
</refentry>

View File

@ -57,7 +57,7 @@ POSTGRES_BKI_SRCS = $(addprefix $(top_srcdir)/src/include/catalog/,\
pg_job.h gs_asp.h pg_job_proc.h pg_extension_data_source.h pg_statistic_ext.h pg_object.h pg_synonym.h \
toasting.h indexing.h gs_obsscaninfo.h pg_directory.h pg_hashbucket.h gs_global_config.h\
pg_streaming_stream.h pg_streaming_cont_query.h pg_streaming_reaper_status.h gs_matview.h\
gs_matview_dependency.h pgxc_slice.h gs_opt_model.h\
gs_matview_dependency.h pgxc_slice.h gs_opt_model.h gs_model.h\
)
# location of Catalog.pm

View File

@ -2088,6 +2088,34 @@
"datetimetz_pl", 1,
AddBuiltinFunc(_0(1297), _1("datetimetz_pl"), _2(2), _3(true), _4(false), _5(datetimetz_timestamptz), _6(1184), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(0), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, 1082, 1266), _21(NULL), _22(NULL), _23(NULL), _24(NULL), _25("datetimetz_timestamptz"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(NULL), _32(false), _33(NULL), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_bool", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_BOOL_OID), _1("db4ai_predict_by_bool"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_bool), _6(BOOLOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_bool"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_float4", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_FLOAT4_OID), _1("db4ai_predict_by_float4"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_float4), _6(FLOAT4OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_float4"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_float8", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_FLOAT8_OID), _1("db4ai_predict_by_float8"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_float8), _6(FLOAT8OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_float8"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_int32", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_INT32_OID), _1("db4ai_predict_by_int32"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_int32), _6(INT4OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_int32"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_int64", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_INT64_OID), _1("db4ai_predict_by_int64"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_int64), _6(INT8OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_int64"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_numeric", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_NUMERIC_OID), _1("db4ai_predict_by_numeric"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_numeric), _6(NUMERICOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_numeric"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"db4ai_predict_by_text", 1,
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_TEXT_OID), _1("db4ai_predict_by_text"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_text), _6(TEXTOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_text"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
),
AddFuncGroup(
"dcbrt", 1,
AddBuiltinFunc(_0(231), _1("dcbrt"), _2(1), _3(true), _4(false), _5(dcbrt), _6(701), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(0), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(1, 701), _21(NULL), _22(NULL), _23(NULL), _24(NULL), _25("dcbrt"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(NULL), _32(false), _33(NULL), _34('f'))
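The db4ai_predict_by_* builtins registered above back per-type model inference. A hedged usage sketch follows; the model, table, and column names are hypothetical, and the PREDICT BY clause spelling is an assumption based on the parser nodes (PredictByFunction) added in this merge:

```sql
-- Score rows with a previously trained model (all identifiers hypothetical).
SELECT id,
       PREDICT BY price_model (FEATURES f1, f2) AS prediction
FROM test_data;
```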

View File

@ -21,6 +21,7 @@
#include "access/xact.h"
#include "catalog/dependency.h"
#include "catalog/gs_matview.h"
#include "catalog/gs_model.h"
#include "catalog/heap.h"
#include "catalog/index.h"
#include "catalog/namespace.h"
@ -1421,7 +1422,9 @@ static void doDeletion(const ObjectAddress* object, int flags)
case OCLASS_SYNONYM:
RemoveSynonymById(object->objectId);
break;
case OCLASS_DB4AI_MODEL:
remove_model_by_oid(object->objectId);
break;
default:
ereport(ERROR,
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE), errmsg("unrecognized object class: %u", object->classId)));
@ -2371,6 +2374,8 @@ ObjectClass getObjectClass(const ObjectAddress* object)
return OCLASS_GLOBAL_SETTING_ARGS;
case ClientLogicColumnSettingsArgsId:
return OCLASS_COLUMN_SETTING_ARGS;
case ModelRelationId:
return OCLASS_DB4AI_MODEL;
default:
break;
}

View File

@ -19,6 +19,7 @@
#include "access/sysattr.h"
#include "catalog/catalog.h"
#include "catalog/indexing.h"
#include "catalog/gs_model.h"
#include "catalog/objectaddress.h"
#include "catalog/pg_authid.h"
#include "catalog/pg_cast.h"
@ -113,6 +114,7 @@ static THR_LOCAL const ObjectPropertyType ObjectProperty[] = {{CastRelationId, C
CLAOID,
Anum_pg_opclass_opcnamespace,
},
{ModelRelationId, GsModelOidIndexId, DB4AI_MODEL, InvalidAttrNumber},
{OperatorRelationId, OperatorOidIndexId, OPEROID, Anum_pg_operator_oprnamespace},
{OperatorFamilyRelationId, OpfamilyOidIndexId, OPFAMILYOID, Anum_pg_opfamily_opfnamespace},
{AuthIdRelationId, AuthIdOidIndexId, AUTHOID, InvalidAttrNumber},
@ -204,6 +206,7 @@ ObjectAddress get_object_address(
address = get_object_address_relobject(objtype, objname, &relation, missing_ok);
break;
case OBJECT_DATABASE:
case OBJECT_DB4AI_MODEL:
case OBJECT_EXTENSION:
case OBJECT_TABLESPACE:
case OBJECT_ROLE:
@ -394,6 +397,9 @@ static ObjectAddress get_object_address_unqualified(ObjectType objtype, List* qu
case OBJECT_DATABASE:
msg = gettext_noop("database name cannot be qualified");
break;
case OBJECT_DB4AI_MODEL:
msg = gettext_noop("model name cannot be qualified");
break;
case OBJECT_EXTENSION:
msg = gettext_noop("extension name cannot be qualified");
break;
@ -440,6 +446,11 @@ static ObjectAddress get_object_address_unqualified(ObjectType objtype, List* qu
address.objectId = get_database_oid(name, missing_ok);
address.objectSubId = 0;
break;
case OBJECT_DB4AI_MODEL:
address.classId = ModelRelationId;
address.objectId = get_model_oid(name, missing_ok);
address.objectSubId = 0;
break;
case OBJECT_EXTENSION:
address.classId = ExtensionRelationId;
address.objectId = get_extension_oid(name, missing_ok);
@ -816,6 +827,7 @@ void check_object_ownership(
if (!pg_type_ownercheck(address.objectId, roleid))
aclcheck_error_type(ACLCHECK_NO_PRIV, address.objectId);
break;
case OBJECT_DB4AI_MODEL:
case OBJECT_DOMAIN:
case OBJECT_ATTRIBUTE:
if (!pg_type_ownercheck(address.objectId, roleid))

View File

@ -27,6 +27,7 @@
#include "catalog/pg_authid.h"
#include "catalog/pg_namespace.h"
#include "catalog/pg_proc.h"
#include "db4ai/predict_by.h"
#include "access/transam.h"
#include "utils/fmgroids.h"
#include "../utils/pg_builtin_proc.h"

View File

@ -5776,6 +5776,20 @@ static DropSynonymStmt* _copyDropSynonymStmt(const DropSynonymStmt* from)
return newnode;
}
// DB4AI
static CreateModelStmt* _copyCreateModelStmt(const CreateModelStmt* from)
{
CreateModelStmt* newnode = makeNode(CreateModelStmt);
COPY_STRING_FIELD(model);
COPY_STRING_FIELD(architecture);
COPY_NODE_FIELD(hyperparameters);
COPY_NODE_FIELD(select_query);
COPY_NODE_FIELD(model_features);
COPY_NODE_FIELD(model_target);
COPY_SCALAR_FIELD(algorithm);
return newnode;
}
/* ****************************************************************
* pg_list.h copy functions
* ****************************************************************
@ -7079,6 +7093,9 @@ void* copyObject(const void* from)
case T_DropSynonymStmt:
retval = _copyDropSynonymStmt((DropSynonymStmt*)from);
break;
case T_CreateModelStmt: // DB4AI
retval = _copyCreateModelStmt((CreateModelStmt*) from);
break;
case T_A_Expr:
retval = _copyAExpr((A_Expr*)from);
break;

View File

@ -231,6 +231,9 @@ Oid exprType(const Node* expr)
case T_Rownum:
type = INT8OID;
break;
case T_GradientDescentExpr:
type = ((const GradientDescentExpr*)expr)->fieldtype;
break;
default:
ereport(ERROR,
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE), errmsg("unrecognized node type: %d", (int)nodeTag(expr))));
@ -840,6 +843,9 @@ Oid exprCollation(const Node* expr)
case T_PlaceHolderVar:
coll = exprCollation((Node*)((const PlaceHolderVar*)expr)->phexpr);
break;
case T_GradientDescentExpr:
coll = InvalidOid;
break;
default:
ereport(
ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), errmsg("unrecognized node type: %d", (int)nodeTag(expr))));
@ -1549,6 +1555,7 @@ bool expression_tree_walker(Node* node, bool (*walker)(), void* context)
case T_Null:
case T_PgFdwRemoteInfo:
case T_Rownum:
case T_GradientDescentExpr:
/* primitive node types with no expression subnodes */
break;
case T_Aggref: {

View File

@ -570,7 +570,17 @@ static const TagStr g_tagStrArr[] = {{T_Invalid, "Invalid"},
{T_SkewRelInfo, "SkewRelInfo"},
{T_SkewColumnInfo, "SkewColumnInfo"},
{T_SkewValueInfo, "SkewValueInfo"},
{T_QualSkewInfo, "QualSkewInfo"}};
{T_QualSkewInfo, "QualSkewInfo"},
// DB4AI
{T_CreateModelStmt, "CreateModelStmt"},
{T_PredictByFunction, "PredictByFunction"},
{T_GradientDescent, "GradientDescent"},
{T_GradientDescentState, "GradientDescentState"},
{T_KMeans, "KMeans"},
{T_KMeansState, "KMeansState"},
{T_GradientDescentExpr, "GradientDescentExpr"},
{T_GradientDescentExprState, "GradientDescentExprState"},
};
char* nodeTagToString(NodeTag tag)
{

View File

@ -48,6 +48,7 @@
#include "pgxc/pgxc.h"
#include "pgxc/pgFdwRemote.h"
#endif
#include "db4ai/gd.h"
/*
* Macros to simplify output of different kinds of fields. Use these
@ -5129,6 +5130,33 @@ static void _outIndexVar(StringInfo str, IndexVar* node)
WRITE_BOOL_FIELD(indexpath);
}
static void _outGradientDescent(StringInfo str, GradientDescent* node)
{
WRITE_NODE_TYPE("SGD");
_outPlanInfo(str, (Plan*)node);
appendStringInfoString(str, " :algorithm ");
appendStringInfoString(str, gd_get_algorithm(node->algorithm)->name);
appendStringInfoString(str, " :optimizer ");
appendStringInfoString(str, gd_get_optimizer_name(node->optimizer));
WRITE_INT_FIELD(targetcol);
WRITE_INT_FIELD(max_iterations);
WRITE_INT_FIELD(max_seconds);
WRITE_INT_FIELD(batch_size);
WRITE_BOOL_FIELD(verbose);
WRITE_FLOAT_FIELD(learning_rate, "%.16g");
WRITE_FLOAT_FIELD(decay, "%.16g");
WRITE_FLOAT_FIELD(tolerance, "%.16g");
WRITE_INT_FIELD(seed);
WRITE_FLOAT_FIELD(lambda, "%.16g");
}
static void _outGradientDescentExpr(StringInfo str, GradientDescentExpr* node)
{
WRITE_NODE_TYPE("GradientDescentExpr");
WRITE_UINT_FIELD(field);
WRITE_OID_FIELD(fieldtype);
}
/*
* _outNode -
* converts a Node into ascii string and append it to 'str'
@ -5945,6 +5973,11 @@ static void _outNode(StringInfo str, const void* obj)
break;
case T_RewriteHint:
_outRewriteHint(str, (RewriteHint *)obj);
break;
case T_GradientDescent:
_outGradientDescent(str, (GradientDescent*)obj);
break;
case T_GradientDescentExpr:
_outGradientDescentExpr(str, (GradientDescentExpr*)obj);
break;
default:

View File

@ -80,6 +80,10 @@
#include "instruments/instr_unique_sql.h"
#include "streaming/init.h"
#include "db4ai/aifuncs.h"
#include "db4ai/create_model.h"
#include "db4ai/hyperparameter_validation.h"
#ifndef ENABLE_MULTIPLE_NODES
#include "optimizer/clauses.h"
#endif
@ -274,6 +278,77 @@ Query* transformTopLevelStmt(ParseState* pstate, Node* parseTree, bool isFirstNo
return transformStmt(pstate, parseTree, isFirstNode, isCreateView);
}
Query* transformCreateModelStmt(ParseState* pstate, CreateModelStmt* stmt)
{
SelectStmt* select_stmt = (SelectStmt*) stmt->select_query;
stmt->algorithm = get_algorithm_ml(stmt->architecture);
if (stmt->algorithm == INVALID_ALGORITHM_ML) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Unrecognized ML model architecture definition %s", stmt->architecture)));
}
if (SearchSysCacheExists1(DB4AI_MODEL, CStringGetDatum(stmt->model))) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("The model name \"%s\" already exists in gs_model_warehouse.", stmt->model)));
}
// Create the projection for the AI operator in the query plan
// If the algorithm is supervised, the target is always the first element of the list
if (is_supervised(stmt->algorithm)) {
if (list_length(stmt->model_features) == 0) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Supervised ML algorithms require FEATURES clause")));
} else if (list_length(stmt->model_target) == 0) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Supervised ML algorithms require TARGET clause")));
} else if (list_length(stmt->model_target) > 1) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Target clause only supports one expression")));
}
} else {
if (list_length(stmt->model_target) > 0) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Unsupervised ML algorithms cannot have TARGET clause")));
}
}
select_stmt->targetList = NIL;
foreach_cell(it, stmt->model_target) {
select_stmt->targetList = lappend(select_stmt->targetList, lfirst(it));
}
if (list_length(stmt->model_features) > 0) { // User given projection
foreach_cell(it, stmt->model_features) {
select_stmt->targetList = lappend(select_stmt->targetList, lfirst(it));
}
} else { // No projection
ResTarget *rt = makeNode(ResTarget);
ColumnRef *cr = makeNode(ColumnRef);
cr->fields = list_make1(makeNode(A_Star));
cr->location = -1;
rt->name = NULL;
rt->indirection = NIL;
rt->val = (Node *)cr;
rt->location = -1;
select_stmt->targetList = lappend(select_stmt->targetList, rt);
}
// Transform the select query that we prepared for the training operator
Query* select_query = transformStmt(pstate, (Node*) select_stmt);
stmt->select_query = (Node*) select_query;
/* represent the command as a utility Query */
Query* result = makeNode(Query);
result->commandType = CMD_UTILITY;
result->utilityStmt = (Node*)stmt;
return result;
}
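A statement handled by transformCreateModelStmt above might look like the following sketch. All identifiers are hypothetical, and the clause spellings (FEATURES, TARGET, WITH) follow the grammar symbols added in gram.y (features_clause, target_clause, with_hyperparameters_clause):

```sql
-- Train a supervised model (all identifiers hypothetical).
CREATE MODEL price_model USING logistic_regression
    FEATURES f1, f2
    TARGET label
    FROM train_data
    WITH learning_rate = 0.01, max_iterations = 100;
```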
/*
* transformStmt -
* recursively transform a Parse tree into a Query tree.
@ -333,6 +408,11 @@ Query* transformStmt(ParseState* pstate, Node* parseTree, bool isFirstNode, bool
result = transformCreateTableAsStmt(pstate, (CreateTableAsStmt*)parseTree);
break;
case T_CreateModelStmt:
result = transformCreateModelStmt(pstate, (CreateModelStmt*) parseTree);
break;
default:
/*

View File

@ -214,6 +214,7 @@ bool IsValidGroupname(const char *input);
static bool checkNlssortArgs(const char *argname);
static void ParseUpdateMultiSet(List *set_target_list, SelectStmt *stmt, core_yyscan_t yyscanner);
static void parameter_check_execute_direct(const char* query);
%}
%pure-parser
@ -316,6 +317,13 @@ static void parameter_check_execute_direct(const char* query);
MergeStmt CreateMatViewStmt RefreshMatViewStmt
CreateWeakPasswordDictionaryStmt DropWeakPasswordDictionaryStmt
/* <DB4AI> */
/* TRAIN MODEL */
%type <node> CreateModelStmt hyperparameter_name_value DropModelStmt
%type <list> features_clause hyperparameter_name_value_list target_clause with_hyperparameters_clause
/* </DB4AI> */
%type <node> select_no_parens select_with_parens select_clause
simple_select values_clause
@ -687,7 +695,7 @@ static void parameter_check_execute_direct(const char* query);
* DOT_DOT is unused in the core SQL grammar, and so will always provoke
* parse errors. It is needed by PL/pgsql.
*/
%token <str> IDENT FCONST SCONST BCONST XCONST Op CmpOp COMMENTSTRING
%token <str> IDENT FCONST SCONST BCONST VCONST XCONST Op CmpOp COMMENTSTRING
%token <ival> ICONST PARAM
%token TYPECAST ORA_JOINOP DOT_DOT COLON_EQUALS PARA_EQUALS
@ -701,8 +709,8 @@ static void parameter_check_execute_direct(const char* query);
/* ordinary key words in alphabetical order */
/* PGXC - added DISTRIBUTE, DIRECT, COORDINATOR, CLEAN, NODE, BARRIER, SLICE, DATANODE */
%token <keyword> ABORT_P ABSOLUTE_P ACCESS ACCOUNT ACTION ADD_P ADMIN AFTER
AGGREGATE ALGORITHM ALL ALSO ALTER ALWAYS ANALYSE ANALYZE AND ANY APP ARRAY AS ASC
ASSERTION ASSIGNMENT ASYMMETRIC AT ATTRIBUTE AUDIT AUTHID AUTHORIZATION AUTOEXTEND AUTOMAPPED
AGGREGATE ALGORITHM ALL ALSO ALTER ALWAYS ANALYSE ANALYZE AND ANY APP ARCHIVE ARRAY AS ASC
ASSERTION ASSIGNMENT ASYMMETRIC AT ATTRIBUTE AUDIT AUTHID AUTHORIZATION AUTOEXTEND AUTOMAPPED AUTOML
BACKWARD BARRIER BEFORE BEGIN_NON_ANOYBLOCK BEGIN_P BETWEEN BIGINT BINARY BINARY_DOUBLE BINARY_INTEGER BIT BLOB_P BOGUS
BOOLEAN_P BOTH BUCKETS BY BYTEAWITHOUTORDER BYTEAWITHOUTORDERWITHEQUAL
@ -728,6 +736,7 @@ static void parameter_check_execute_direct(const char* query);
EXTENSION EXTERNAL EXTRACT
FALSE_P FAMILY FAST FENCED FETCH FILEHEADER_P FILL_MISSING_FIELDS FILTER FIRST_P FIXED_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORMATTER FORWARD
FEATURES // DB4AI
FREEZE FROM FULL FUNCTION FUNCTIONS
GENERATED GLOBAL GLOBAL_FUNCTION GRANT GRANTED GREATEST GROUP_P GROUPING_P
@ -748,6 +757,7 @@ static void parameter_check_execute_direct(const char* query);
LEAST LESS LEFT LEVEL LIKE LIMIT LIST LISTEN LOAD LOCAL LOCALTIME LOCALTIMESTAMP
LOCATION LOCK_P LOG_P LOGGING LOGIN_ANY LOGIN_FAILURE LOGIN_SUCCESS LOGOUT LOOP
MAPPING MASKING MASTER MATCH MATERIALIZED MATCHED MAXEXTENTS MAXSIZE MAXTRANS MAXVALUE MERGE MINUS_P MINUTE_P MINVALUE MINEXTENTS MODE MODIFY_P MONTH_P MOVE MOVEMENT
MODEL // DB4AI
NAME_P NAMES NATIONAL NATURAL NCHAR NEXT NLSSORT NO NOCOMPRESS NOCYCLE NODE NOLOGGING NOMAXVALUE NOMINVALUE NONE
NOT NOTHING NOTIFY NOTNULL NOWAIT NULL_P NULLIF NULLS_P NUMBER_P NUMERIC NUMSTR NVARCHAR2 NVL
@ -756,24 +766,28 @@ static void parameter_check_execute_direct(const char* query);
PACKAGE PARSER PARTIAL PARTITION PARTITIONS PASSING PASSWORD PCTFREE PER_P PERCENT PERFORMANCE PERM PLACING PLAN PLANS POLICY POSITION
/* PGXC_BEGIN */
POOL PRECEDING PRECISION PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
POOL PRECEDING PRECISION
/* PGXC_END */
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE
PREDICT // DB4AI
/* PGXC_BEGIN */
PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
/* PGXC_END */
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE PUBLISH PURGE
QUERY QUOTE
RANDOMIZED RANGE RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
RANDOMIZED RANGE RATIO RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
RELATIVE_P RELEASE RELOPTIONS REMOTE_P REMOVE RENAME REPEATABLE REPLACE REPLICA
RESET RESIZE RESOURCE RESTART RESTRICT RETURN RETURNING RETURNS REUSE REVOKE RIGHT ROLE ROLES ROLLBACK ROLLUP
ROW ROWNUM ROWS RULE
SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
SAMPLE SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
SERIALIZABLE SERVER SESSION SESSION_USER SET SETS SETOF SHARE SHIPPABLE SHOW SHUTDOWN
SIMILAR SIMPLE SIZE SLICE SMALLDATETIME SMALLDATETIME_FORMAT_P SMALLINT SNAPSHOT SOME SOURCE_P SPACE SPILL SPLIT STABLE STANDALONE_P START
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STREAM STRICT_P STRIP_P SUBSTRING
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STRATIFY STREAM STRICT_P STRIP_P SUBSTRING
SYMMETRIC SYNONYM SYSDATE SYSID SYSTEM_P SYS_REFCURSOR
TABLE TABLES TABLESAMPLE TABLESPACE TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
TABLE TABLES TABLESAMPLE TABLESPACE TARGET TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
TO TRAILING TRANSACTION TREAT TRIGGER TRIM TRUE_P
TRUNCATE TRUSTED TSFIELD TSTAG TSTIME TYPE_P TYPES_P
@ -812,6 +826,7 @@ static void parameter_check_execute_direct(const char* query);
%nonassoc PARTIAL_EMPTY_PREC
%nonassoc CLUSTER
%nonassoc SET /* see relation_expr_opt_alias */
%right FEATURES TARGET // DB4AI
%left UNION EXCEPT MINUS_P
%left INTERSECT
%left OR
@ -852,7 +867,7 @@ static void parameter_check_execute_direct(const char* query);
*/
%nonassoc UNBOUNDED /* ideally should have same precedence as IDENT */
%nonassoc IDENT GENERATED NULL_P PARTITION RANGE ROWS PRECEDING FOLLOWING CUBE ROLLUP
%left Op OPERATOR /* multi-character ops and user-defined operators */
%left Op OPERATOR '@' /* multi-character ops and user-defined operators */
%nonassoc NOTNULL
%nonassoc ISNULL
%nonassoc IS /* sets precedence for IS NULL, etc */
@ -987,6 +1002,7 @@ stmt :
| CreateFunctionStmt
| CreateGroupStmt
| CreateMatViewStmt
| CreateModelStmt // DB4AI
| CreateNodeGroupStmt
| CreateNodeStmt
| CreateOpClassStmt
@ -1033,6 +1049,7 @@ stmt :
| DropFdwStmt
| DropForeignServerStmt
| DropGroupStmt
| DropModelStmt // DB4AI
| DropNodeGroupStmt
| DropNodeStmt
| DropOpClassStmt
@ -7585,6 +7602,133 @@ AlterUserMappingStmt: ALTER USER MAPPING FOR auth_ident SERVER name alter_generi
}
;
/*****************************************************************************
*
* QUERY:
* CREATE MODEL <model_name>
*
*
*****************************************************************************/
CreateModelStmt:
CREATE MODEL ColId
USING ColId
features_clause
target_clause
from_clause
with_hyperparameters_clause
{
CreateModelStmt *n = makeNode(CreateModelStmt);
n->model = pstrdup($3);
n->architecture = pstrdup($5);
n->model_features = $6;
n->model_target = $7;
// The projection clause will be constructed in transform
SelectStmt *s = makeNode(SelectStmt);
s->fromClause = $8;
n->select_query = (Node*) s;
n->hyperparameters = $9;
$$ = (Node*) n;
}
;
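The semantic action above collects each clause into a CreateModelStmt node; the embedded SELECT initially carries only the FROM clause, with the projection filled in later during transform. A minimal Python sketch of that field mapping (field and example values are illustrative, not the kernel structs):

```python
def make_create_model_stmt(model, architecture, features, target,
                           from_clause, hyperparameters):
    """Rough analogue of the CreateModelStmt action: the embedded SELECT
    starts with only the FROM clause; its target list is added in transform."""
    return {
        "model": model,                 # $3
        "architecture": architecture,   # $5
        "model_features": features,     # $6
        "model_target": target,         # $7
        "select_query": {"fromClause": from_clause},  # $8
        "hyperparameters": hyperparameters,           # $9
    }

stmt = make_create_model_stmt("price_model", "linear_regression",
                              ["area", "rooms"], ["price"],
                              ["houses"], {"learning_rate": "0.1"})
```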
features_clause:
FEATURES target_list{
List* result = $2;
// Verify that the FEATURES clause is not '*'
foreach_cell(it, result){
ResTarget* n = (ResTarget*) lfirst(it);
ColumnRef* cr = (n->val != NULL && IsA(n->val, ColumnRef)) ? (ColumnRef*)(n->val) : NULL;
List* l = (cr != NULL) ? cr->fields : NULL;
Node* node = list_length(l) > 0 ? linitial_node(Node, l) : NULL;
if (node != NULL && IsA(node, A_Star)){
elog(ERROR, "FEATURES clause cannot be *");
}
}
$$ = result;
}
| {
List* result = NULL;
$$ = result;
}
;
target_clause:
TARGET target_list{
List* result = $2;
// Verify that target clause is not '*'
foreach_cell(it, result){
ResTarget* n = (ResTarget*) lfirst(it);
ColumnRef* cr = (n->val != NULL && IsA(n->val, ColumnRef)) ? (ColumnRef*)(n->val) : NULL;
List* l = (cr != NULL) ? cr->fields : NULL;
Node* node = list_length(l) > 0 ? linitial_node(Node, l) : NULL;
if (node != NULL && IsA(node, A_Star)){
elog(ERROR, "TARGET clause cannot be *");
}
}
$$ = result;
}
| {
List* result = NULL;
$$ = result;
}
;
with_hyperparameters_clause:
WITH hyperparameter_name_value_list { $$ = $2; }
| { $$ = NULL; }
;
hyperparameter_name_value_list:
hyperparameter_name_value { $$ = list_make1($1); }
| hyperparameter_name_value_list ',' hyperparameter_name_value
{
$$ = lappend($1,$3);
}
;
hyperparameter_name_value:
ColLabel '=' var_value
{
VariableSetStmt *n = makeNode(VariableSetStmt);
n->kind = VAR_SET_VALUE;
n->name = $1;
n->args = list_make1($3);
$$ = (Node*) n;
}
| ColLabel '=' DEFAULT
{
VariableSetStmt *n = makeNode(VariableSetStmt);
n->kind = VAR_SET_DEFAULT;
n->name = $1;
n->args = NULL;
$$ = (Node*) n;
}
;
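Each hyperparameter in the WITH clause becomes a VariableSetStmt: `name = value` sets a value, while `name = DEFAULT` resets it (no args). A rough Python analogue of that rule (a sketch, not the parser's actual data structures):

```python
def parse_hyperparameter(clause):
    """Sketch of the hyperparameter_name_value rule: 'name = value' sets a
    value, 'name = DEFAULT' resets to the default (empty args list)."""
    name, _, value = (part.strip() for part in clause.partition("="))
    if value.upper() == "DEFAULT":
        return {"kind": "VAR_SET_DEFAULT", "name": name, "args": []}
    return {"kind": "VAR_SET_VALUE", "name": name, "args": [value]}

print(parse_hyperparameter("max_iterations = 100"))
print(parse_hyperparameter("batch_size = DEFAULT"))
```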
DropModelStmt:
DROP MODEL ColId opt_drop_behavior
{
DropStmt *n = makeNode(DropStmt);
n->removeType = OBJECT_DB4AI_MODEL;
n->objects = list_make1(list_make1(makeString($3)));
n->arguments = NULL;
n->behavior = $4;
n->missing_ok = false;
n->concurrent = false;
n->isProcedure = false;
$$ = (Node *)n;
}
;
/*****************************************************************************
*
* QUERIES For ROW LEVEL SECURITY:
@ -14823,6 +14967,7 @@ ExplainableStmt:
| MergeStmt
| DeclareCursorStmt
| CreateAsStmt
| CreateModelStmt
| ExecuteStmt /* by default all are $$=$1 */
;
@ -17840,6 +17985,15 @@ a_expr: c_expr { $$ = $1; }
list_make1($1), @2),
@2);
}
| PREDICT BY ColId '(' FEATURES func_arg_list ')'
{
PredictByFunction * n = makeNode(PredictByFunction);
n->model_name = $3;
n->model_name_location = @3;
n->model_args = $6;
n->model_args_location = @6;
$$ = (Node*) n;
}
;
/*
@ -20191,6 +20345,7 @@ unreserved_keyword:
| EXTERNAL
| FAMILY
| FAST
| FEATURES // DB4AI
| FILEHEADER_P
| FILL_MISSING_FIELDS
| FILTER
@ -20276,6 +20431,7 @@ unreserved_keyword:
| MINUTE_P
| MINVALUE
| MODE
| MODEL // DB4AI
| MONTH_P
| MOVE
| MOVEMENT
@ -20321,6 +20477,7 @@ unreserved_keyword:
| POLICY
| POOL
| PRECEDING
| PREDICT // DB4AI
/* PGXC_BEGIN */
| PREFERRED
/* PGXC_END */
@ -20338,6 +20495,7 @@ unreserved_keyword:
| QUOTE
| RANDOMIZED
| RANGE
| RATIO
| RAW '(' Iconst ')' { $$ = "raw";}
| RAW %prec UNION { $$ = "raw";}
| READ
@ -20373,6 +20531,7 @@ unreserved_keyword:
| ROLLUP
| ROWS
| RULE
| SAMPLE
| SAVEPOINT
| SCHEMA
| SCROLL
@ -20410,7 +20569,8 @@ unreserved_keyword:
| STORAGE
| STORE_P
| STORED
| STREAM
| STRATIFY
| STREAM
| STRICT_P
| STRIP_P
| SYNONYM
@ -20419,6 +20579,7 @@ unreserved_keyword:
| SYSTEM_P
| TABLES
| TABLESPACE
| TARGET
| TEMP
| TEMPLATE
| TEMPORARY
@ -22198,6 +22359,7 @@ static void parameter_check_execute_direct(const char* query)
errmsg("must be system admin or monitor admin to use EXECUTE DIRECT")));
}
/*
* Must undefine this stuff before including scan.c, since it has different
* definitions for these macros.

View File

@ -20,6 +20,7 @@
#include "catalog/pg_proc.h"
#include "commands/dbcommands.h"
#include "commands/sequence.h"
#include "db4ai/predict_by.h"
#include "foreign/foreign.h"
#include "miscadmin.h"
#include "nodes/makefuncs.h"
@ -67,6 +68,7 @@ static Node* transformXmlExpr(ParseState* pstate, XmlExpr* x);
static Node* transformXmlSerialize(ParseState* pstate, XmlSerialize* xs);
static Node* transformBooleanTest(ParseState* pstate, BooleanTest* b);
static Node* transformCurrentOfExpr(ParseState* pstate, CurrentOfExpr* cexpr);
static Node* transformPredictByFunction(ParseState* pstate, PredictByFunction* cexpr);
static Node* transformColumnRef(ParseState* pstate, ColumnRef* cref);
static Node* transformWholeRowRef(ParseState* pstate, RangeTblEntry* rte, int location);
static Node* transformIndirection(ParseState* pstate, Node* basenode, List* indirection);
@ -286,6 +288,10 @@ Node* transformExpr(ParseState* pstate, Node* expr)
result = transformCurrentOfExpr(pstate, (CurrentOfExpr*)expr);
break;
case T_PredictByFunction:
result = transformPredictByFunction(pstate, (PredictByFunction*) expr);
break;
/*********************************************
* Quietly accept node types that may be presented when we are
* called on an already-transformed tree.
@ -2027,6 +2033,102 @@ static Node* transformCurrentOfExpr(ParseState* pstate, CurrentOfExpr* cexpr)
return (Node*)cexpr;
}
// Select the type-specific prediction function for a model's return type
static char* select_prediction_function(Model* model){
char* result;
switch(model->return_type){
case BOOLOID:
result = "db4ai_predict_by_bool";
break;
case FLOAT4OID:
result = "db4ai_predict_by_float4";
break;
case FLOAT8OID:
result = "db4ai_predict_by_float8";
break;
case INT1OID:
case INT2OID:
case INT4OID:
result = "db4ai_predict_by_int32";
break;
case INT8OID:
result = "db4ai_predict_by_int64";
break;
case NUMERICOID:
result = "db4ai_predict_by_numeric";
break;
case VARCHAROID:
case BPCHAROID:
case CHAROID:
case TEXTOID:
result = "db4ai_predict_by_text";
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Cannot trigger prediction for model with return type oid %u", model->return_type)));
result = NULL;
break;
}
return result;
}
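select_prediction_function maps the model's return-type oid onto one of the type-specific db4ai_predict_by_* functions. The same dispatch can be sketched in Python (the oid constants below use the usual pg_type values and are assumptions for illustration, not taken from this patch):

```python
# Common pg_type oids, used here only for illustration.
BOOLOID, FLOAT4OID, FLOAT8OID = 16, 700, 701
INT2OID, INT4OID, INT8OID = 21, 23, 20
NUMERICOID, TEXTOID = 1700, 25

PREDICT_FUNCTION_BY_TYPE = {
    BOOLOID: "db4ai_predict_by_bool",
    FLOAT4OID: "db4ai_predict_by_float4",
    FLOAT8OID: "db4ai_predict_by_float8",
    INT2OID: "db4ai_predict_by_int32",   # small int widths share one function
    INT4OID: "db4ai_predict_by_int32",
    INT8OID: "db4ai_predict_by_int64",
    NUMERICOID: "db4ai_predict_by_numeric",
    TEXTOID: "db4ai_predict_by_text",
}

def select_prediction_function(return_type):
    """Dispatch on the model's return-type oid; unknown oids are an error."""
    try:
        return PREDICT_FUNCTION_BY_TYPE[return_type]
    except KeyError:
        raise ValueError(f"no prediction function for return type oid {return_type}")
```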
// Convert the PredictByFunction created during parsing phase into the function
// call that computes the prediction of the model. We cannot do this during parsing
// because at that moment we do not know the return type of the model, which is obtained
// from the catalog
static Node* transformPredictByFunction(ParseState* pstate, PredictByFunction* p)
{
FuncCall* n = makeNode(FuncCall);
if (p->model_name == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Model name for prediction cannot be null")));
}
Model* model = get_model(p->model_name, true);
if (model == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("No model found with name %s", p->model_name)));
}
// Locate the proper function according to the model name
char* function_name = select_prediction_function(model);
ereport(DEBUG1, (errmsg(
"Selecting prediction function %s for model %s",
function_name, p->model_name)));
n->funcname = list_make1(makeString(function_name));
n->colname = p->model_name;
// Fill model name parameter
A_Const* model_name_aconst = makeNode(A_Const);
model_name_aconst->val.type = T_String;
model_name_aconst->val.val.str = p->model_name;
model_name_aconst->location = p->model_name_location;
// Copy other parameters
n->args = list_make1(model_name_aconst);
if (list_length(p->model_args) > 0) {
n->args = lappend3(n->args, p->model_args);
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Input features for the model not specified")));
}
n->agg_order = NULL;
n->agg_star = FALSE;
n->agg_distinct = FALSE;
n->func_variadic = FALSE;
n->over = NULL;
n->location = p->model_args_location;
n->call_func = false;
return transformExpr(pstate, (Node*)n);
}
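The rewrite above turns `PREDICT BY model (FEATURES ...)` into a call of the chosen prediction function with the model name prepended as the first argument, erroring when no features are given. A minimal Python sketch of that argument assembly (function and field names are illustrative):

```python
def rewrite_predict_by(model_name, model_args, prediction_function):
    """Prepend the model name to the feature arguments of the selected
    db4ai_predict_by_* function, as transformPredictByFunction does."""
    if not model_args:
        raise ValueError("input features for the model not specified")
    return {"funcname": prediction_function,
            "args": [model_name] + list(model_args)}

call = rewrite_predict_by("m1", ["age", "income"], "db4ai_predict_by_float8")
# call["args"] is ["m1", "age", "income"]
```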
/*
* Construct a whole-row reference to represent the notation "relation.*".
*/

View File

@ -1495,6 +1495,14 @@ static int FigureColnameInternal(Node* node, char** name)
return 2;
}
break;
case T_PredictByFunction: {
size_t len = strlen(((PredictByFunction*)node)->model_name) + strlen("_pred") + 1;
char* colname = (char*)palloc0(len);
errno_t rc = snprintf_s(colname, len, len - 1, "%s_pred", ((PredictByFunction*)node)->model_name);
securec_check_ss(rc, "\0", "\0");
*name = colname;
return 1;
} break;
case T_TypeCast:
strength = FigureColnameInternal(((TypeCast*)node)->arg, name);
if (strength <= 1) {

View File

@ -92,6 +92,7 @@
#include "utils/typcache.h"
#include "utils/xml.h"
#include "vecexecutor/vecnodes.h"
#include "db4ai/gd.h"
/* ----------
* Pretty formatting constants
@ -9234,6 +9235,11 @@ static void get_rule_expr(Node* node, deparse_context* context, bool showimplici
}
} break;
case T_GradientDescentExpr: {
GradientDescentExpr* gdnode = (GradientDescentExpr*)node;
appendStringInfo(buf, "GD(%s)", gd_get_expr_name(gdnode->field));
} break;
default:
if (context->qrw_phase)
appendStringInfo(buf, "<unknown %d>", (int)nodeTag(node));

View File

@ -138,6 +138,7 @@
#include "catalog/pg_streaming_stream.h"
#include "catalog/pg_streaming_cont_query.h"
#include "catalog/pg_streaming_reaper_status.h"
#include "catalog/gs_model.h"
#include "commands/matview.h"
#include "commands/sec_rls_cmds.h"
#include "commands/tablespace.h"
@ -292,6 +293,7 @@ static const FormData_pg_attribute Desc_gs_client_global_keys[Natts_gs_client_gl
static const FormData_pg_attribute Desc_gs_client_global_keys_args[Natts_gs_client_global_keys_args] = {Schema_gs_client_global_keys_args};
static const FormData_pg_attribute Desc_gs_opt_model[Natts_gs_opt_model] = {Schema_gs_opt_model};
static const FormData_pg_attribute Desc_gs_model_warehouse[Natts_gs_model_warehouse] = {Schema_gs_model_warehouse};
/* Please add to the array in ascending order of oid value */
static struct CatalogRelationBuildParam catalogBuildParam[CATALOG_NUM] = {{DefaultAclRelationId,
@ -768,6 +770,15 @@ static struct CatalogRelationBuildParam catalogBuildParam[CATALOG_NUM] = {{Defau
Desc_pg_ts_template,
false,
true},
{ModelRelationId,
"gs_model_warehouse",
ModelRelation_Rowtype_Id,
false,
true,
Natts_gs_model_warehouse,
Desc_gs_model_warehouse,
false,
true},
{DataSourceRelationId,
"pg_extension_data_source",
DataSourceRelation_Rowtype_Id,

View File

@ -25,6 +25,7 @@
#include "access/sysattr.h"
#include "catalog/gs_obsscaninfo.h"
#include "catalog/gs_opt_model.h"
#include "catalog/gs_model.h"
#include "catalog/gs_policy_label.h"
#include "catalog/indexing.h"
#include "catalog/pg_aggregate.h"
@ -299,6 +300,16 @@ static const struct cachedesc cacheinfo[] = {{AggregateRelationId, /* AGGFNOID *
1,
{ObjectIdAttributeNumber, 0, 0, 0},
32},
{ModelRelationId, /* DB4AI_MODELOID */
GsModelOidIndexId,
1,
{ObjectIdAttributeNumber, 0, 0, 0},
256},
{ModelRelationId, /* DB4AI_MODEL */
GsModelNameIndexId,
1,
{Anum_gs_model_model_name, 0, 0, 0},
256},
{DefaultAclRelationId, /* DEFACLROLENSPOBJ */
DefaultAclRoleNspObjIndexId,
3,

View File

@ -93,6 +93,7 @@ const module_data module_map[] = {{MOD_ALL, "ALL"},
{MOD_THREAD_POOL, "THREAD_POOL"},
{MOD_OPT_AI, "OPT_AI"},
{MOD_GEN_COL, "GEN_COL"},
{MOD_DB4AI, "DB4AI"},
/* add your module name above */
{MOD_MAX, "BACKEND"}};

View File

@ -519,6 +519,7 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
EXTENSION EXTERNAL EXTRACT
FALSE_P FAMILY FAST FENCED FETCH FILEHEADER_P FILL_MISSING_FIELDS FILTER FIRST_P FIXED_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORMATTER FORWARD
FEATURES // DB4AI
FREEZE FROM FULL FUNCTION FUNCTIONS
GENERATED GLOBAL GLOBAL_FUNCTION GRANT GRANTED GREATEST GROUP_P GROUPING_P
@ -538,6 +539,7 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
LEAST LESS LEFT LEVEL LIST LIKE LIMIT LISTEN LOAD LOCAL LOCALTIME LOCALTIMESTAMP
LOCATION LOCK_P LOG_P LOGGING LOGIN_ANY LOGIN_SUCCESS LOGIN_FAILURE LOGOUT LOOP
MAPPING MASKING MASTER MASTR MATCH MATERIALIZED MATCHED MAXEXTENTS MAXSIZE MAXTRANS MAXVALUE MERGE MINUS_P MINUTE_P MINVALUE MINEXTENTS MODE MODIFY_P MONTH_P MOVE MOVEMENT
MODEL // DB4AI
NAME_P NAMES NATIONAL NATURAL NCHAR NEXT NLSSORT NO NOCOMPRESS NOCYCLE NODE NOLOGGING NOMAXVALUE NOMINVALUE NONE
NOT NOTHING NOTIFY NOTNULL NOWAIT NULL_P NULLIF NULLS_P NUMBER_P NUMERIC NUMSTR NVARCHAR2 NVL
@ -546,24 +548,28 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
PACKAGE PARSER PARTIAL PARTITION PARTITIONS PASSING PASSWORD PCTFREE PER_P PERCENT PERFORMANCE PERM PLACING PLAN PLANS POLICY POSITION
/* PGXC_BEGIN */
POOL PRECEDING PRECISION PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
POOL PRECEDING PRECISION
/* PGXC_END */
PREDICT
/* PGXC_BEGIN */
PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
/* PGXC_END */
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE
QUERY QUOTE
RANDOMIZED RANGE RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
RANDOMIZED RANGE RATIO RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
RELATIVE_P RELEASE RELOPTIONS REMOTE_P REMOVE RENAME REPEATABLE REPLACE REPLICA
RESET RESIZE RESOURCE RESTART RESTRICT RETURN RETURNING RETURNS REUSE REVOKE RIGHT ROLE ROLES ROLLBACK ROLLUP
ROW ROWNUM ROWS RULE
SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
SAMPLE SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
SERIALIZABLE SERVER SESSION SESSION_USER SET SETS SETOF SHARE SHIPPABLE SHOW SHUTDOWN
SIMILAR SIMPLE SIZE SLICE SMALLDATETIME SMALLDATETIME_FORMAT_P SMALLINT SNAPSHOT SOME SOURCE_P SPACE SPILL SPLIT STABLE STANDALONE_P START
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STREAM STRICT_P STRIP_P SUBSTRING
SYMMETRIC SYNONYM SYSDATE SYSID SYSTEM_P SYS_REFCURSOR
TABLE TABLES TABLESAMPLE TABLESPACE TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
TABLE TABLES TABLESAMPLE TABLESPACE TARGET TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
TO TRAILING TRANSACTION TREAT TRIGGER TRIM TRUE_P
TRUNCATE TRUSTED TSFIELD TSTAG TSTIME TYPE_P TYPES_P
@ -10714,6 +10720,7 @@ unreserved_keyword:
| EXTERNAL
| FAMILY
| FAST
| FEATURES // DB4AI
| FILEHEADER_P
| FILL_MISSING_FIELDS
| FIRST_P
@ -10796,6 +10803,7 @@ unreserved_keyword:
| MINUTE_P
| MINVALUE
| MODE
| MODEL // DB4AI
| MONTH_P
| MOVE
| MOVEMENT
@ -10838,6 +10846,7 @@ unreserved_keyword:
| PLANS
| POOL
| PRECEDING
| PREDICT // DB4AI
/* PGXC_BEGIN */
| PREFERRED
/* PGXC_END */
@ -10931,6 +10940,7 @@ unreserved_keyword:
| SYSTEM_P
| TABLES
| TABLESPACE
| TARGET // DB4AI
| TEMP
| TEMPLATE
| TEMPORARY

View File

@ -301,7 +301,7 @@ static bool last_pragma;
* Some of these are not directly referenced in this file, but they must be
* here anyway.
*/
%token <str> IDENT FCONST SCONST BCONST XCONST Op CmpOp COMMENTSTRING
%token <str> IDENT FCONST SCONST BCONST VCONST XCONST Op CmpOp COMMENTSTRING
%token <ival> ICONST PARAM
%token TYPECAST ORA_JOINOP DOT_DOT COLON_EQUALS PARA_EQUALS

View File

@ -9,6 +9,6 @@ subdir = src/gausskernel/dbmind
top_builddir = ../../..
include $(top_builddir)/src/Makefile.global
SUBDIRS = kernel
SUBDIRS = kernel db4ai
include $(top_srcdir)/src/gausskernel/common.mk

View File

@ -0,0 +1,20 @@
# This is the main CMAKE for building bin.
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_RULE_MESSAGES OFF)
set(CMAKE_SKIP_RPATH TRUE)
set(CMAKE_MODULE_PATH
${CMAKE_CURRENT_SOURCE_DIR}/catalog
${CMAKE_CURRENT_SOURCE_DIR}/commands
${CMAKE_CURRENT_SOURCE_DIR}/executor
)
add_subdirectory(catalog)
add_subdirectory(commands)
add_subdirectory(executor)
if("${ENABLE_MULTIPLE_NODES}" STREQUAL "OFF")
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/snapshots
DESTINATION share/postgresql/db4ai/ OPTIONAL
)
endif()

View File

@ -0,0 +1,24 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/Makefile
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai
top_builddir = ../../../..
include $(top_builddir)/src/Makefile.global
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
SUBDIRS = catalog commands executor
include $(top_srcdir)/src/gausskernel/common.mk

View File

@ -0,0 +1,24 @@
# This is the main CMAKE for building all components.
AUX_SOURCE_DIRECTORY(${CMAKE_CURRENT_SOURCE_DIR} TGT_catalog_SRC)
set(TGT_catalog_INC
${PROJECT_OPENGS_DIR}/contrib/log_fdw
${PROJECT_TRUNK_DIR}/distribute/bin/gds
${PROJECT_SRC_DIR}/include/libcomm
${PROJECT_SRC_DIR}/include
${PROJECT_SRC_DIR}/lib/gstrace
${LZ4_INCLUDE_PATH}
${LIBCGROUP_INCLUDE_PATH}
${LIBORC_INCLUDE_PATH}
${EVENT_INCLUDE_PATH}
${PROTOBUF_INCLUDE_PATH}
${ZLIB_INCLUDE_PATH}
)
set(catalog_DEF_OPTIONS ${MACRO_OPTIONS})
set(catalog_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
list(REMOVE_ITEM catalog_COMPILE_OPTIONS -fPIC)
set(catalog_COMPILE_OPTIONS ${catalog_COMPILE_OPTIONS} -std=c++14 -fPIE)
set(catalog_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_catalog TGT_catalog_SRC TGT_catalog_INC "${catalog_DEF_OPTIONS}" "${catalog_COMPILE_OPTIONS}" "${catalog_LINK_OPTIONS}")

View File

@ -0,0 +1,22 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/catalog/Makefile
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai/catalog
top_builddir = ../../../../..
include $(top_builddir)/src/Makefile.global
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
OBJS = model_warehouse.o
include $(top_srcdir)/src/gausskernel/common.mk

View File

@ -0,0 +1,760 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* model_warehouse.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/catalog/model_warehouse.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/model_warehouse.h"
#include "db4ai/gd.h"
#include "db4ai/aifuncs.h"
#include "access/tableam.h"
#include "catalog/gs_model.h"
#include "catalog/indexing.h"
#include "catalog/pg_proc.h"
#include "instruments/generate_report.h"
#include "lib/stringinfo.h"
#include "utils/fmgroids.h"
#include "utils/builtins.h"
#include "utils/lsyscache.h"
typedef enum ListType {
HYPERPARAMETERS = 0,
COEFS,
SCORES,
} ListType;
template <ListType ltype> void ListToTuple(List *list, Datum *name, Datum *value, Datum *oid);
template <ListType listType> void TupleToList(Model *model, Datum *names, Datum *values, Datum *oids);
template <ListType ltype> static void add_model_parameter(Model *model, const char *name, Oid type, Datum value);
static Datum string_to_datum(const char *str, Oid datatype);
void store_SGD(Datum *values, bool *nulls, ModelGradientDescent *SGDmodel);
void get_SGD(HeapTuple *tuple, ModelGradientDescent *resGD, Form_gs_model_warehouse tuplePointer);
void store_kmeans(Datum *values, bool *nulls, ModelKMeans *kmeansModel);
void get_kmeans(HeapTuple *tuple, ModelKMeans *modelKmeans);
void splitStringFillCentroid(WHCentroid *curseCent, char *strDescribe);
char *splitStringFillCoordinates(WHCentroid *curseCent, char *strCoordinates, int dimension);
// Store the model in the catalog tables
void store_model(const Model *model)
{
HeapTuple tuple;
int rc;
Relation rel = NULL;
Oid extOwner = GetUserId();
Datum values[Natts_gs_model_warehouse];
bool nulls[Natts_gs_model_warehouse];
Datum ListNames, ListValues, ListOids;
if (SearchSysCacheExists1(DB4AI_MODEL, CStringGetDatum(model->model_name))) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("The model name \"%s\" already exists in gs_model_warehouse.", model->model_name)));
}
rel = heap_open(ModelRelationId, RowExclusiveLock);
rc = memset_s(values, sizeof(values), 0, sizeof(values));
securec_check(rc, "\0", "\0");
rc = memset_s(nulls, sizeof(nulls), 0, sizeof(nulls));
securec_check(rc, "\0", "\0");
values[Anum_gs_model_model_name - 1] = DirectFunctionCall1(namein, CStringGetDatum(model->model_name));
values[Anum_gs_model_owner_oid - 1] = ObjectIdGetDatum(extOwner);
values[Anum_gs_model_create_time - 1] = DirectFunctionCall1(timestamptz_timestamp, GetCurrentTimestamp());
values[Anum_gs_model_processedTuples - 1] = Int64GetDatum(model->processed_tuples);
values[Anum_gs_model_discardedTuples - 1] = Int64GetDatum(model->discarded_tuples);
values[Anum_gs_model_process_time_secs - 1] = Float4GetDatum(model->pre_time_secs);
values[Anum_gs_model_exec_time_secs - 1] = Float4GetDatum(model->exec_time_secs);
values[Anum_gs_model_iterations - 1] = Int64GetDatum(model->num_actual_iterations);
values[Anum_gs_model_outputType - 1] = ObjectIdGetDatum(model->return_type);
values[Anum_gs_model_query - 1] = CStringGetTextDatum(model->sql);
if (model->hyperparameters == nullptr) {
nulls[Anum_gs_model_hyperparametersNames - 1] = true;
nulls[Anum_gs_model_hyperparametersValues - 1] = true;
nulls[Anum_gs_model_hyperparametersOids - 1] = true;
} else {
ListToTuple<ListType::HYPERPARAMETERS>(model->hyperparameters, &ListNames, &ListValues, &ListOids);
values[Anum_gs_model_hyperparametersNames - 1] = ListNames;
values[Anum_gs_model_hyperparametersValues - 1] = ListValues;
values[Anum_gs_model_hyperparametersOids - 1] = ListOids;
}
if (model->train_info == nullptr) {
nulls[Anum_gs_model_coefNames - 1] = true;
nulls[Anum_gs_model_coefValues - 1] = true;
nulls[Anum_gs_model_coefOids - 1] = true;
} else {
ListToTuple<ListType::COEFS>(model->train_info, &ListNames, &ListValues, &ListOids);
values[Anum_gs_model_coefNames - 1] = ListNames;
values[Anum_gs_model_coefValues - 1] = ListValues;
values[Anum_gs_model_coefOids - 1] = ListOids;
}
if (model->scores == nullptr) {
nulls[Anum_gs_model_trainingScoresName - 1] = true;
nulls[Anum_gs_model_trainingScoresValue - 1] = true;
} else {
ListToTuple<ListType::SCORES>(model->scores, &ListNames, &ListValues, &ListOids);
values[Anum_gs_model_trainingScoresName - 1] = ListNames;
values[Anum_gs_model_trainingScoresValue - 1] = ListValues;
}
switch (model->algorithm) {
case LOGISTIC_REGRESSION:
case SVM_CLASSIFICATION:
case LINEAR_REGRESSION: {
store_SGD(values, nulls, (ModelGradientDescent *)model);
} break;
case KMEANS: {
store_kmeans(values, nulls, (ModelKMeans *)model);
} break;
default:
// do not cache
ereport(NOTICE, (errmsg("clone model for type %d", (int)model->algorithm)));
break;
}
tuple = heap_form_tuple(rel->rd_att, values, nulls);
(void)simple_heap_insert(rel, tuple);
CatalogUpdateIndexes(rel, tuple);
heap_freetuple_ext(tuple);
heap_close(rel, RowExclusiveLock);
}
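store_model fills a values/nulls pair per optional list: when a list is absent its catalog columns are marked null, otherwise the serialized values are stored. A condensed Python sketch of that null-mask pattern (field names are illustrative, one flag standing in for each list's name/value/oid columns):

```python
def build_row(model):
    """Per optional list, either fill the value columns or mark them null,
    mirroring the nulls[]/values[] handling in store_model."""
    values, nulls = {}, set()
    for field in ("hyperparameters", "train_info", "scores"):
        if model.get(field) is None:
            nulls.add(field)          # nulls[Anum_...] = true
        else:
            values[field] = model[field]
    return values, nulls

values, nulls = build_row({"hyperparameters": {"lr": "0.1"},
                           "train_info": None, "scores": None})
```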
// Get the model from the catalog tables
Model *get_model(const char *model_name, bool only_model)
{
void *result = NULL;
Model *model = NULL;
Datum ListNames, ListValues, ListOids;
bool isnull = false;
bool isnullValue = false;
bool isnullOid = false;
AlgorithmML algorithm;
if (t_thrd.proc->workingVersionNum < 92304) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("gs_model_warehouse is not supported before GRAND VERSION NUM 92304.")));
}
HeapTuple tuple = SearchSysCache1(DB4AI_MODEL, CStringGetDatum(model_name));
if (!HeapTupleIsValid(tuple)) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("There is no model called \"%s\".", model_name)));
return NULL;
}
Form_gs_model_warehouse tuplePointer = (Form_gs_model_warehouse)GETSTRUCT(tuple);
const char *modelType = TextDatumGetCString(SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_model_type, &isnull));
algorithm = get_algorithm_ml(modelType);
if (algorithm == INVALID_ALGORITHM_ML) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("the type of model is invalid: %s", modelType)));
}
if (only_model) {
result = palloc0(sizeof(Model));
model = (Model *)result;
} else {
switch (algorithm) {
case LOGISTIC_REGRESSION:
case SVM_CLASSIFICATION:
case LINEAR_REGRESSION: {
result = palloc0(sizeof(ModelGradientDescent));
ModelGradientDescent *resGD = (ModelGradientDescent *)result;
get_SGD(&tuple, resGD, tuplePointer);
model = &(resGD->model);
} break;
case KMEANS: {
result = palloc0(sizeof(ModelKMeans));
ModelKMeans *resKmeans = (ModelKMeans *)result;
get_kmeans(&tuple, resKmeans);
model = &(resKmeans->model);
} break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("the type of model is invalid: %s", modelType)));
break;
}
}
model->algorithm = algorithm;
model->model_name = model_name;
model->exec_time_secs = tuplePointer->exec_time;
model->pre_time_secs = tuplePointer->pre_process_time;
model->processed_tuples = tuplePointer->processedtuples;
model->discarded_tuples = tuplePointer->discardedtuples;
model->return_type = tuplePointer->outputtype;
model->num_actual_iterations = tuplePointer->iterations;
model->sql = TextDatumGetCString(SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_query, &isnull));
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersNames, &isnull);
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersValues, &isnullValue);
ListOids = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersOids, &isnullOid);
if (!isnull && !isnullValue && !isnullOid) {
TupleToList<ListType::HYPERPARAMETERS>(model, &ListNames, &ListValues, &ListOids);
}
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefNames, &isnull);
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefValues, &isnullValue);
ListOids = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefOids, &isnullOid);
if (!isnull && !isnullValue && !isnullOid) {
TupleToList<ListType::COEFS>(model, &ListNames, &ListValues, &ListOids);
}
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_trainingScoresName, &isnull);
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_trainingScoresValue, &isnullValue);
if (!isnull && !isnullValue) {
TupleToList<ListType::SCORES>(model, &ListNames, &ListValues, NULL);
}
ReleaseSysCache(tuple);
return (Model *)result;
}
void elog_model(int level, const Model *model)
{
Oid typoutput;
bool typIsVarlena;
ListCell *lc;
StringInfoData buf;
initStringInfo(&buf);
const char* model_type = algorithm_ml_to_string(model->algorithm);
appendStringInfo(&buf, "\n:type %s", model_type);
appendStringInfo(&buf, "\n:sql %s", model->sql);
if (model->hyperparameters != nullptr) {
appendStringInfoString(&buf, "\n:hyperparameters");
foreach (lc, model->hyperparameters) {
Hyperparameter *hyperp = lfirst_node(Hyperparameter, lc);
getTypeOutputInfo(hyperp->type, &typoutput, &typIsVarlena);
appendStringInfo(&buf, "\n :%s %s", hyperp->name, OidOutputFunctionCall(typoutput, hyperp->value));
}
}
appendStringInfo(&buf, "\n:return type %u", model->return_type);
appendStringInfo(&buf, "\n:pre-processing time %lf s", model->pre_time_secs);
appendStringInfo(&buf, "\n:exec time %lf s", model->exec_time_secs);
appendStringInfo(&buf, "\n:processed %ld tuples", model->processed_tuples);
appendStringInfo(&buf, "\n:discarded %ld tuples", model->discarded_tuples);
appendStringInfo(&buf, "\n:actual number of iterations %d", model->num_actual_iterations);
if (model->train_info != nullptr) {
appendStringInfoString(&buf, "\n:info");
foreach (lc, model->train_info) {
TrainingInfo *info = lfirst_node(TrainingInfo, lc);
getTypeOutputInfo(info->type, &typoutput, &typIsVarlena);
appendStringInfo(&buf, "\n :%s %s", info->name, OidOutputFunctionCall(typoutput, info->value));
}
}
if (model->scores != nullptr) {
appendStringInfoString(&buf, "\n:scores");
foreach (lc, model->scores) {
TrainingScore *score = lfirst_node(TrainingScore, lc);
appendStringInfo(&buf, "\n :%s %.16g", score->name, score->value);
}
}
if (model->algorithm == LOGISTIC_REGRESSION ||
model->algorithm == SVM_CLASSIFICATION ||
model->algorithm == LINEAR_REGRESSION) {
ModelGradientDescent *model_gd = (ModelGradientDescent *)model;
appendStringInfoString(&buf, "\n:gradient_descent:");
appendStringInfo(&buf, "\n :algorithm %s", gd_get_algorithm(model_gd->model.algorithm)->name);
getTypeOutputInfo(FLOAT4ARRAYOID, &typoutput, &typIsVarlena);
appendStringInfo(&buf, "\n :weights %s", OidOutputFunctionCall(typoutput, model_gd->weights));
if (model_gd->ncategories > 0) {
Datum dt;
bool isnull;
bool first = true;
struct varlena *src_arr = (struct varlena *)DatumGetPointer(model_gd->categories);
ArrayType *arr = (ArrayType *)pg_detoast_datum(src_arr);
Assert(arr->elemtype == model->return_type);
ArrayIterator it = array_create_iterator(arr, 0);
getTypeOutputInfo(model->return_type, &typoutput, &typIsVarlena);
appendStringInfo(&buf, "\n :categories %d {", model_gd->ncategories);
while (array_iterate(it, &dt, &isnull)) {
Assert(!isnull);
appendStringInfo(&buf, "%s%s", first ? "" : ",", OidOutputFunctionCall(typoutput, dt));
first = false;
}
appendStringInfoString(&buf, "}");
array_free_iterator(it);
if (arr != (ArrayType *)src_arr)
pfree(arr);
}
}
elog(level, "Model=%s%s", model->model_name, buf.data);
pfree(buf.data);
}
template <ListType ltype> void ListToTuple(List *list, Datum *name, Datum *value, Datum *oid)
{
text *t_names, *t_values;
Datum *array_container = nullptr;
int iter = 0;
ArrayBuildState *astateName = NULL, *astateValue = NULL;
array_container = (Datum *)palloc0(list->length * sizeof(Datum));
foreach_cell(it, list)
{
switch (ltype) {
case ListType::HYPERPARAMETERS: {
Hyperparameter *cell = (Hyperparameter *)lfirst(it);
t_names = cstring_to_text(cell->name);
t_values = cstring_to_text(Datum_to_string(cell->value, cell->type, false));
array_container[iter] = ObjectIdGetDatum(cell->type);
astateName =
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
astateValue =
accumArrayResult(astateValue, PointerGetDatum(t_values), false, TEXTOID, CurrentMemoryContext);
iter++;
} break;
case ListType::COEFS: {
TrainingInfo *cell = (TrainingInfo *)lfirst(it);
t_names = cstring_to_text(cell->name);
t_values = cstring_to_text(Datum_to_string(cell->value, cell->type, false));
array_container[iter] = ObjectIdGetDatum(cell->type);
astateName =
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
astateValue =
accumArrayResult(astateValue, PointerGetDatum(t_values), false, TEXTOID, CurrentMemoryContext);
iter++;
} break;
case ListType::SCORES: {
TrainingScore *cell = (TrainingScore *)lfirst(it);
t_names = cstring_to_text(cell->name);
array_container[iter] = Float4GetDatum(cell->value);
astateName =
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
iter++;
} break;
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("saving this kind of data into the model warehouse is not supported.")));
} break;
}
}
switch (ltype) {
case ListType::HYPERPARAMETERS:
case ListType::COEFS: {
*name = makeArrayResult(astateName, CurrentMemoryContext);
*value = makeArrayResult(astateValue, CurrentMemoryContext);
ArrayType *oid_array = construct_array(array_container, list->length, OIDOID, sizeof(Oid), true, 'i');
*oid = PointerGetDatum(oid_array);
} break;
case ListType::SCORES: {
*name = makeArrayResult(astateName, CurrentMemoryContext);
ArrayType *value_array =
construct_array(array_container, list->length, FLOAT4OID, sizeof(float4), FLOAT4PASSBYVAL, 'i');
*value = PointerGetDatum(value_array);
} break;
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("saving this kind of data into the model warehouse is not supported.")));
break;
}
}
}
template <ListType listType> void TupleToList(Model *model, Datum *names, Datum *values, Datum *oids)
{
char *strNames, *strValues;
Oid tranOids;
Datum dtNames, dtValues, dtOid;
ArrayType *arrNames, *arrValues, *arrOids;
bool isnull;
ArrayIterator itOid = NULL;
arrNames = DatumGetArrayTypeP(*names);
arrValues = DatumGetArrayTypeP(*values);
ArrayIterator itName = array_create_iterator(arrNames, 0);
ArrayIterator itValue = array_create_iterator(arrValues, 0);
if (oids != NULL) {
arrOids = DatumGetArrayTypeP(*oids);
itOid = array_create_iterator(arrOids, 0);
}
while (array_iterate(itName, &dtNames, &isnull)) {
array_iterate(itValue, &dtValues, &isnull);
switch (listType) {
case ListType::HYPERPARAMETERS: {
array_iterate(itOid, &dtOid, &isnull);
strNames = TextDatumGetCString(dtNames);
strValues = TextDatumGetCString(dtValues);
tranOids = DatumGetObjectId(dtOid);
dtValues = string_to_datum(strValues, tranOids);
add_model_parameter<listType>(model, strNames, tranOids, dtValues);
} break;
case ListType::COEFS: {
array_iterate(itOid, &dtOid, &isnull);
strNames = TextDatumGetCString(dtNames);
strValues = TextDatumGetCString(dtValues);
tranOids = DatumGetObjectId(dtOid);
dtValues = string_to_datum(strValues, tranOids);
add_model_parameter<listType>(model, strNames, tranOids, dtValues);
} break;
case ListType::SCORES: {
strNames = TextDatumGetCString(dtNames);
add_model_parameter<listType>(model, strNames, 0, dtValues);
} break;
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("fetching this kind of data from the model warehouse is not supported.")));
return;
}
}
}
return;
}
template <ListType ltype> static void add_model_parameter(Model *model, const char *name, Oid type, Datum value)
{
switch (ltype) {
case ListType::HYPERPARAMETERS: {
Hyperparameter *hyperp = (Hyperparameter *)palloc0(sizeof(Hyperparameter));
hyperp->name = pstrdup(name);
hyperp->type = type;
hyperp->value = value;
model->hyperparameters = lappend(model->hyperparameters, hyperp);
} break;
case ListType::COEFS: {
TrainingInfo *tinfo = (TrainingInfo *)palloc0(sizeof(TrainingInfo));
tinfo->name = pstrdup(name);
tinfo->type = type;
tinfo->value = value;
model->train_info = lappend(model->train_info, tinfo);
} break;
case ListType::SCORES: {
TrainingScore *tscore = (TrainingScore *)palloc0(sizeof(TrainingScore));
tscore->name = pstrdup(name);
tscore->value = DatumGetFloat4(value);
model->scores = lappend(model->scores, tscore);
} break;
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("storing this kind of data into the Model struct is not supported.")));
return;
}
}
}
static Datum string_to_datum(const char *str, Oid datatype)
{
switch (datatype) {
case BOOLOID:
return DirectFunctionCall1(boolin, CStringGetDatum(str));
case INT1OID:
case INT2OID:
case INT4OID:
return Int32GetDatum(atoi(str));
case INT8OID:
return Int64GetDatum(atoll(str)); /* 64-bit parse; atoi would truncate large values */
case VARCHAROID:
case BPCHAROID:
case CHAROID:
case TEXTOID:
return CStringGetTextDatum(str);
case FLOAT4OID:
case FLOAT8OID:
return DirectFunctionCall1(float8in, CStringGetDatum(str));
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("The type is not supported: %d", datatype)));
return CStringGetTextDatum(str);
}
}
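The dispatch in string_to_datum keys on the type OID and funnels everything through the kernel's Datum machinery. As a rough standalone illustration of the same idea (SqlClass, Value, and parse_value are invented names for this sketch, not kernel types), the conversion looks like:

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <string>

// Hypothetical stand-in (outside the Datum machinery) for string_to_datum:
// parse a string into the C type backing each class of SQL type. Note that
// the 64-bit case must use strtoll; atoi would silently truncate large values.
enum class SqlClass { Bool, Int32, Int64, Float8, Text };

struct Value {
    bool b = false;
    int32_t i32 = 0;
    int64_t i64 = 0;
    double f8 = 0.0;
    std::string text;
};

static Value parse_value(const char *str, SqlClass cls)
{
    Value v;
    switch (cls) {
        case SqlClass::Bool:
            v.b = (std::strcmp(str, "t") == 0 || std::strcmp(str, "true") == 0);
            break;
        case SqlClass::Int32:
            v.i32 = static_cast<int32_t>(std::strtol(str, nullptr, 10));
            break;
        case SqlClass::Int64:
            v.i64 = std::strtoll(str, nullptr, 10);  // full 64-bit range
            break;
        case SqlClass::Float8:
            v.f8 = std::strtod(str, nullptr);
            break;
        case SqlClass::Text:
            v.text = str;
            break;
    }
    return v;
}
```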
/* store SGD model */
void store_SGD(Datum *values, bool *nulls, ModelGradientDescent *SGDmodel)
{
nulls[Anum_gs_model_modelData - 1] = true;
nulls[Anum_gs_model_modeldescribe - 1] = true;
values[Anum_gs_model_model_type - 1] = CStringGetTextDatum(algorithm_ml_to_string(SGDmodel->model.algorithm));
values[Anum_gs_model_weight - 1] = SGDmodel->weights;
if (SGDmodel->ncategories > 0) {
text *categoriesName, *categoriesValue;
ArrayBuildState *astate = NULL;
Datum dt;
bool isnull;
nulls[Anum_gs_model_coefNames - 1] = false;
categoriesName = cstring_to_text("categories");
astate = accumArrayResult(astate, PointerGetDatum(categoriesName), false, TEXTOID, CurrentMemoryContext);
values[Anum_gs_model_coefNames - 1] = makeArrayResult(astate, CurrentMemoryContext);
astate = NULL;
ArrayType *arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(SGDmodel->categories));
ArrayIterator it = array_create_iterator(arr, 0);
while (array_iterate(it, &dt, &isnull)) {
categoriesValue = cstring_to_text(Datum_to_string(dt, SGDmodel->model.return_type, false));
astate = accumArrayResult(astate, PointerGetDatum(categoriesValue), false, TEXTOID, CurrentMemoryContext);
}
values[Anum_gs_model_coefValues - 1] = makeArrayResult(astate, CurrentMemoryContext);
nulls[Anum_gs_model_coefValues - 1] = false;
}
}
/* get SGD model */
void get_SGD(HeapTuple *tuple, ModelGradientDescent *resGD, Form_gs_model_warehouse tuplePointer)
{
char *strValues;
Datum dtValues;
ArrayBuildState *astate = NULL;
bool isnull = false;
/* weight */
resGD->weights = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_weight, &isnull);
/* categories */
resGD->ncategories = 0;
Datum dtCat = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefValues, &isnull);
if (!isnull) {
ArrayType *arrValues = DatumGetArrayTypeP(dtCat);
ArrayIterator itValue = array_create_iterator(arrValues, 0);
while (array_iterate(itValue, &dtValues, &isnull)) {
resGD->ncategories++;
strValues = TextDatumGetCString(dtValues);
dtValues = string_to_datum(strValues, tuplePointer->outputtype);
astate = accumArrayResult(astate, dtValues, false, tuplePointer->outputtype, CurrentMemoryContext);
}
resGD->categories = makeArrayResult(astate, CurrentMemoryContext);
} else {
resGD->categories = PointerGetDatum(NULL);
}
}
/* store kmeans model */
void store_kmeans(Datum *values, bool *nulls, ModelKMeans *kmeansModel)
{
ArrayBuildState *astateName = NULL, *astateValue = NULL, *astateDescribe = NULL;
text *tValue, *txDescribe, *txCoodinate;
int lenDescribe = 200 * sizeof(char), lengthD = 0, lengthC = 0;
int lenCoodinate = 15 * kmeansModel->dimension * kmeansModel->actual_num_centroids;
WHCentroid *centroid;
double *coordinateContainer;
char *describeElem, *strCoordinates = (char *)palloc0(lenCoodinate);
values[Anum_gs_model_outputType - 1] = ObjectIdGetDatum(kmeansModel->model.return_type);
nulls[Anum_gs_model_modelData - 1] = true;
nulls[Anum_gs_model_weight - 1] = true;
values[Anum_gs_model_model_type - 1] = CStringGetTextDatum(algorithm_ml_to_string(KMEANS));
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->original_num_centroids), INT8OID, false));
astateName = accumArrayResult(astateName, CStringGetTextDatum("original_num_centroids"), false, TEXTOID,
CurrentMemoryContext);
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->actual_num_centroids), INT8OID, false));
astateName =
accumArrayResult(astateName, CStringGetTextDatum("actual_num_centroids"), false, TEXTOID, CurrentMemoryContext);
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->dimension), INT8OID, false));
astateName = accumArrayResult(astateName, CStringGetTextDatum("dimension"), false, TEXTOID, CurrentMemoryContext);
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->distance_function_id), INT8OID, false));
astateName =
accumArrayResult(astateName, CStringGetTextDatum("distance_function_id"), false, TEXTOID, CurrentMemoryContext);
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->seed), INT8OID, false));
astateName = accumArrayResult(astateName, CStringGetTextDatum("seed"), false, TEXTOID, CurrentMemoryContext);
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
astateName = accumArrayResult(astateName, CStringGetTextDatum("coordinates"), false, TEXTOID, CurrentMemoryContext);
for (uint32_t i = 0; i < kmeansModel->actual_num_centroids; i++) {
lengthD = 0;
describeElem = (char *)palloc0(lenDescribe);
centroid = kmeansModel->centroids + i;
lengthD = sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "id:%d,", centroid->id);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "objective_function:%f,",
centroid->objective_function);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "avg_distance_to_centroid:%f,",
centroid->avg_distance_to_centroid);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "min_distance_to_centroid:%f,",
centroid->min_distance_to_centroid);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "max_distance_to_centroid:%f,",
centroid->max_distance_to_centroid);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "std_dev_distance_to_centroid:%f,",
centroid->std_dev_distance_to_centroid);
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "cluster_size:%d", centroid->cluster_size);
txDescribe = cstring_to_text(describeElem);
astateDescribe =
accumArrayResult(astateDescribe, PointerGetDatum(txDescribe), false, TEXTOID, CurrentMemoryContext);
coordinateContainer = centroid->coordinates;
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, "(");
for (uint32_t j = 0; j < kmeansModel->dimension; j++) {
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, "%f,", coordinateContainer[j]);
}
lengthC--;
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, ")");
}
txCoodinate = cstring_to_text(strCoordinates);
astateValue = accumArrayResult(astateValue, PointerGetDatum(txCoodinate), false, TEXTOID, CurrentMemoryContext);
values[Anum_gs_model_modeldescribe - 1] = makeArrayResult(astateDescribe, CurrentMemoryContext);
values[Anum_gs_model_coefValues - 1] = makeArrayResult(astateValue, CurrentMemoryContext);
values[Anum_gs_model_coefNames - 1] = makeArrayResult(astateName, CurrentMemoryContext);
nulls[Anum_gs_model_coefValues - 1] = false;
nulls[Anum_gs_model_coefNames - 1] = false;
return;
}
void get_kmeans(HeapTuple *tuple, ModelKMeans *modelKmeans)
{
Datum dtValue, dtName;
bool isnull;
char *strValue, *strName, *coordinates = NULL;
uint32_t coefContainer;
int offset = 0;
WHCentroid *curseCent;
modelKmeans->model.algorithm = KMEANS;
/* coef */
Datum dtCoefValues = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefValues, &isnull);
ArrayType *arrValues = DatumGetArrayTypeP(dtCoefValues);
ArrayIterator itValue = array_create_iterator(arrValues, 0);
Datum dtCoefNames = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefNames, &isnull);
ArrayType *arrNames = DatumGetArrayTypeP(dtCoefNames);
ArrayIterator itName = array_create_iterator(arrNames, 0);
while (array_iterate(itName, &dtName, &isnull)) {
array_iterate(itValue, &dtValue, &isnull);
strName = TextDatumGetCString(dtName);
strValue = TextDatumGetCString(dtValue);
coefContainer = atoi(strValue);
if (strcmp(strName, "original_num_centroids") == 0) {
modelKmeans->original_num_centroids = coefContainer;
} else if (strcmp(strName, "actual_num_centroids") == 0) {
modelKmeans->actual_num_centroids = coefContainer;
} else if (strcmp(strName, "seed") == 0) {
modelKmeans->seed = coefContainer;
} else if (strcmp(strName, "dimension") == 0) {
modelKmeans->dimension = coefContainer;
} else if (strcmp(strName, "distance_function_id") == 0) {
modelKmeans->distance_function_id = coefContainer;
} else if (strcmp(strName, "coordinates") == 0) {
coordinates = strValue;
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
errmsg("unexpected coefficient name in KMEANS model: %s", strName)));
}
}
modelKmeans->centroids =
reinterpret_cast<WHCentroid *>(palloc0(sizeof(WHCentroid) * modelKmeans->actual_num_centroids));
/* describe */
Datum dtDescribe = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_modeldescribe, &isnull);
ArrayType *arrDescribe = DatumGetArrayTypeP(dtDescribe);
ArrayIterator itDescribe = array_create_iterator(arrDescribe, 0);
while (array_iterate(itDescribe, &dtName, &isnull)) {
curseCent = modelKmeans->centroids + offset;
strName = TextDatumGetCString(dtName);
coordinates = splitStringFillCoordinates(curseCent, coordinates, modelKmeans->dimension);
splitStringFillCentroid(curseCent, strName);
offset++;
}
}
void splitStringFillCentroid(WHCentroid *curseCent, char *strDescribe)
{
char *cur, *name, *context = NULL;
Datum dtCur;
name = strtok_r(strDescribe, ":,", &context);
cur = strtok_r(NULL, ":,", &context);
while (cur != NULL && name != NULL) {
if (strcmp(name, "id") == 0) {
dtCur = string_to_datum(cur, INT8OID);
curseCent->id = DatumGetUInt32(dtCur);
} else if (strcmp(name, "objective_function") == 0) {
dtCur = string_to_datum(cur, FLOAT8OID);
curseCent->objective_function = DatumGetFloat8(dtCur);
} else if (strcmp(name, "avg_distance_to_centroid") == 0) {
dtCur = string_to_datum(cur, FLOAT8OID);
curseCent->avg_distance_to_centroid = DatumGetFloat8(dtCur);
} else if (strcmp(name, "min_distance_to_centroid") == 0) {
dtCur = string_to_datum(cur, FLOAT8OID);
curseCent->min_distance_to_centroid = DatumGetFloat8(dtCur);
} else if (strcmp(name, "max_distance_to_centroid") == 0) {
dtCur = string_to_datum(cur, FLOAT8OID);
curseCent->max_distance_to_centroid = DatumGetFloat8(dtCur);
} else if (strcmp(name, "std_dev_distance_to_centroid") == 0) {
dtCur = string_to_datum(cur, FLOAT8OID);
curseCent->std_dev_distance_to_centroid = DatumGetFloat8(dtCur);
} else if (strcmp(name, "cluster_size") == 0) {
dtCur = string_to_datum(cur, INT8OID);
curseCent->cluster_size = DatumGetUInt64(dtCur);
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
errmsg("unexpected description field in KMEANS model: %s", name)));
}
name = strtok_r(NULL, ":,", &context);
cur = strtok_r(NULL, ":,", &context);
}
}
char *splitStringFillCoordinates(WHCentroid *curseCent, char *strCoordinates, int dimension)
{
char *cur, *context = NULL;
Datum dtCur;
int iter = 0;
double *res = (double *)palloc0(dimension * sizeof(double));
while (iter < dimension) {
if (iter == 0) {
cur = strtok_r(strCoordinates, ")(,", &context);
} else {
cur = strtok_r(NULL, ")(,", &context);
}
if (cur != NULL) {
dtCur = string_to_datum(cur, FLOAT8OID);
res[iter] = DatumGetFloat8(dtCur);
iter++;
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("the stored coordinates do not match the model's dimension or actual_num_centroids.")));
}
}
curseCent->coordinates = res;
return context;
}


@@ -0,0 +1,24 @@
# This is the main CMake file for building all components.
AUX_SOURCE_DIRECTORY(${CMAKE_CURRENT_SOURCE_DIR} TGT_commands_SRC)
set(TGT_commands_INC
${PROJECT_OPENGS_DIR}/contrib/log_fdw
${PROJECT_TRUNK_DIR}/distribute/bin/gds
${PROJECT_SRC_DIR}/include/libcomm
${PROJECT_SRC_DIR}/include
${PROJECT_SRC_DIR}/lib/gstrace
${LZ4_INCLUDE_PATH}
${LIBCGROUP_INCLUDE_PATH}
${LIBORC_INCLUDE_PATH}
${EVENT_INCLUDE_PATH}
${PROTOBUF_INCLUDE_PATH}
${ZLIB_INCLUDE_PATH}
)
set(commands_DEF_OPTIONS ${MACRO_OPTIONS})
set(commands_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
list(REMOVE_ITEM commands_COMPILE_OPTIONS -fPIC)
set(commands_COMPILE_OPTIONS ${commands_COMPILE_OPTIONS} -std=c++14 -fPIE)
set(commands_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_commands TGT_commands_SRC TGT_commands_INC "${commands_DEF_OPTIONS}" "${commands_COMPILE_OPTIONS}" "${commands_LINK_OPTIONS}")


@@ -0,0 +1,22 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/commands/Makefile
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai/commands
top_builddir = ../../../../..
include $(top_builddir)/src/Makefile.global
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
OBJS = create_model.o predict_by.o
include $(top_srcdir)/src/gausskernel/common.mk


@@ -0,0 +1,751 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* create_model.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/commands/create_model.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/create_model.h"
#include "postgres.h"
#include "knl/knl_variable.h"
#include "db4ai/model_warehouse.h"
#include "db4ai/hyperparameter_validation.h"
#include "catalog/indexing.h"
#include "executor/executor.h"
#include "executor/nodeKMeans.h"
#include "nodes/value.h"
#include "parser/analyze.h"
#include "rewrite/rewriteHandler.h"
#include "utils/snapmgr.h"
#include "tcop/tcopprot.h"
#include "utils/lsyscache.h"
#include "utils/rel.h"
#include "workload/workload.h"
#include "executor/nodeGD.h"
#include "db4ai/aifuncs.h"
#include "utils/builtins.h"
extern void exec_simple_plan(PlannedStmt *plan); // defined in postgres.cpp
bool verify_pgarray(ArrayType const * pg_array, int32_t n); // defined in kmeans.cpp
/*
* Common setup needed by both normal execution and EXPLAIN ANALYZE.
* This setup is adapted from SetupForCreateTableAs
*/
static Query *setup_for_create_model(Query *query, /* IntoClause *into, */ const char *queryString,
ParamListInfo params /* , DestReceiver *dest */)
{
List *rewritten = NIL;
Assert(query->commandType == CMD_SELECT);
/*
* Parse analysis was done already, but we still have to run the rule
* rewriter. We do not do AcquireRewriteLocks: we assume the query either
* came straight from the parser, or suitable locks were acquired by
* plancache.c.
*
* Because the rewriter and planner tend to scribble on the input, we make
* a preliminary copy of the source querytree. This prevents problems in
* the case that CTAS is in a portal or plpgsql function and is executed
* repeatedly. (See also the same hack in EXPLAIN and PREPARE.)
*/
rewritten = QueryRewrite((Query *)copyObject(query));
/* SELECT should never rewrite to more or less than one SELECT query */
if (list_length(rewritten) != 1) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
errmsg("Unexpected rewrite result for CREATE MODEL statement")));
}
query = (Query *)linitial(rewritten);
return query;
}
// Create a GradientDescent execution node with a given configuration
static GradientDescent *create_gd_node(AlgorithmML algorithm, List *hyperparameters, DestReceiverTrainModel *dest)
{
GradientDescent *gd_node = makeNode(GradientDescent);
gd_node->algorithm = algorithm;
configure_hyperparameters(algorithm, hyperparameters, dest->model, gd_node);
if (gd_node->seed == 0) {
gd_node->seed = time(NULL); // time(NULL) is never zero, since the epoch is in the past
update_model_hyperparameter(dest->model, "seed", INT4OID, Int32GetDatum(gd_node->seed));
}
gd_node->plan.type = T_GradientDescent;
gd_node->plan.targetlist = makeGradientDescentExpr(algorithm, nullptr, 1);
dest->targetlist = gd_node->plan.targetlist;
return gd_node;
}
// Add a GradientDescent operator at the root of the plan
static PlannedStmt *add_GradientDescent_to_plan(PlannedStmt *plan, AlgorithmML algorithm, List *hyperparameters,
DestReceiverTrainModel *dest)
{
GradientDescent *gd_node = create_gd_node(algorithm, hyperparameters, dest);
gd_node->plan.lefttree = plan->planTree;
plan->planTree = &gd_node->plan;
return plan;
}
static DistanceFunction get_kmeans_distance(const char *distance_func)
{
DistanceFunction distance = KMEANS_L2_SQUARED;
if (strcmp(distance_func, "L1") == 0)
distance = KMEANS_L1;
else if (strcmp(distance_func, "L2") == 0)
distance = KMEANS_L2;
else if (strcmp(distance_func, "L2_Squared") == 0)
distance = KMEANS_L2_SQUARED;
else if (strcmp(distance_func, "Linf") == 0)
distance = KMEANS_LINF;
else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("No known distance function chosen. Current candidates are: "
"L1, L2, L2_Squared (default), Linf")));
}
return distance;
}
static SeedingFunction get_kmeans_seeding(const char *seeding_func)
{
SeedingFunction seeding = KMEANS_RANDOM_SEED;
if (strcmp(seeding_func, "Random++") == 0)
seeding = KMEANS_RANDOM_SEED;
else if (strcmp(seeding_func, "KMeans||") == 0)
seeding = KMEANS_BB;
else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("No known seeding function chosen. Current candidates are: Random++ (default), KMeans||")));
}
return seeding;
}
static KMeans *create_kmeans_node(AlgorithmML const algorithm, List *hyperparameters, DestReceiverTrainModel *dest)
{
KMeans *kmeans_node = makeNode(KMeans);
char *distance_func = nullptr;
char *seeding_func = nullptr;
double tolerance = 0.;
int32_t num_iterations = 0;
int32_t num_centroids = 0;
int32_t const max_num_centroids = 1000000;
int32_t batch_size = 0;
int32_t num_features = 0;
int32_t external_seed = 0;
int32_t verbosity = 0;
auto kmeans_model = reinterpret_cast<ModelKMeans *>(dest->model);
HyperparameterValidation validation;
memset_s(&validation, sizeof(HyperparameterValidation), 0, sizeof(HyperparameterValidation));
kmeans_node->algorithm = algorithm;
kmeans_node->plan.type = T_KMeans;
kmeans_model->model.return_type = INT4OID;
set_hyperparameter<int32_t>("max_iterations", &num_iterations, hyperparameters, 10, dest->model, &validation);
if (unlikely(num_iterations <= 0)) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("max_iterations must be in [1, %d]", INT_MAX)));
} else {
kmeans_node->parameters.num_iterations = num_iterations;
}
set_hyperparameter<int32_t>("num_centroids", &num_centroids, hyperparameters, 10, dest->model, &validation);
if (unlikely((num_centroids <= 0) || (num_centroids > max_num_centroids))) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("num_centroids must be in [1, %d]", max_num_centroids)));
} else {
kmeans_node->parameters.num_centroids = num_centroids;
}
set_hyperparameter<double>("tolerance", &tolerance, hyperparameters, 0.00001, dest->model, &validation);
if (unlikely((tolerance <= 0.) || (tolerance > 1.))) {
ereport(ERROR,
(errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("tolerance must be in (0, 1.0]")));
} else {
kmeans_node->parameters.tolerance = tolerance;
}
set_hyperparameter<int32_t>("batch_size", &batch_size, hyperparameters, 10, dest->model, &validation);
if (unlikely((batch_size <= 0) || (batch_size > max_num_centroids))) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("batch_size must be in [1, %d]", max_num_centroids)));
} else {
kmeans_node->description.batch_size = batch_size;
}
set_hyperparameter<int32_t>("num_features", &num_features, hyperparameters, 2, dest->model, &validation);
if (unlikely((num_features <= 0) || (num_features > max_num_centroids))) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("num_features must be in [1, %d]", max_num_centroids)));
} else {
kmeans_node->description.n_features = num_features;
}
set_hyperparameter<char *>("distance_function", &distance_func, hyperparameters, "L2_Squared", dest->model,
&validation);
kmeans_node->description.distance = get_kmeans_distance(distance_func);
set_hyperparameter<char *>("seeding_function", &seeding_func, hyperparameters, "Random++", dest->model,
&validation);
kmeans_node->description.seeding = get_kmeans_seeding(seeding_func);
set_hyperparameter<int32_t>("verbose", &verbosity, hyperparameters, 0, dest->model, &validation);
if (verbosity < 0 || verbosity > 2)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("verbose must be 0 (no output), 1 (less output), or 2 (full output)")));
else
kmeans_node->description.verbosity = static_cast<Verbosity>(verbosity);
/*
* unfortunately the system parses an int64_t as T_Float whenever it does not fit into an int32_t.
* thus, an int64_t might be internally parsed as a T_Integer or a T_Float depending on whether
* its value fits into an int32_t. therefore, we accept a small seed (int32_t) that is xor'ed with
* a random long internal seed.
*/
set_hyperparameter<int32_t>("seed", &external_seed, hyperparameters, 0, dest->model, &validation);
/*
* the seed used for the algorithm is the xor of the seed provided by the user with
* a random (but fixed) internal seed. as long as the internal seed is kept unchanged
* results will be reproducible (see nodeKMeans.cpp)
*/
if (external_seed < 0)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("seed must be in [0, %d]", INT_MAX)));
else
kmeans_node->parameters.external_seed = static_cast<uint64_t>(external_seed);
/*
* these fields are propagated all the way to store_model and used for prediction
* the values of ModelKMeans fields not set here change during execution
* and thus are set in the very end, when the model is about to be stored
*/
kmeans_model->dimension = kmeans_node->description.n_features;
kmeans_model->original_num_centroids = kmeans_node->parameters.num_centroids;
kmeans_model->distance_function_id = kmeans_node->description.distance;
pfree(distance_func);
pfree(seeding_func);
return kmeans_node;
}
// Add a k-means operator at the root of the plan
static PlannedStmt *add_kmeans_to_plan(PlannedStmt *plan, AlgorithmML algorithm, List *hyperparameters,
DestReceiverTrainModel *dest)
{
if (unlikely(algorithm != KMEANS)) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Algorithm is not the expected %u (k-means). Provided %u", KMEANS, algorithm)));
}
KMeans *kmeans_node = create_kmeans_node(algorithm, hyperparameters, dest);
kmeans_node->plan.lefttree = plan->planTree;
plan->planTree = &kmeans_node->plan;
return plan;
}
// Add the ML algorithm at the root of the plan according to the CreateModelStmt
static PlannedStmt *add_create_model_to_plan(CreateModelStmt *stmt, PlannedStmt *plan, DestReceiverTrainModel *dest)
{
PlannedStmt *result = NULL;
switch (stmt->algorithm) {
case LOGISTIC_REGRESSION:
case SVM_CLASSIFICATION:
case LINEAR_REGRESSION: {
result = add_GradientDescent_to_plan(plan, stmt->algorithm, stmt->hyperparameters, dest);
break;
}
case KMEANS: {
result = add_kmeans_to_plan(plan, stmt->algorithm, stmt->hyperparameters, dest);
break;
}
case INVALID_ALGORITHM_ML:
default: {
char *s = "logistic_regression, svm_classification, linear_regression, kmeans";
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Architecture %s is not supported. Supported architectures: %s", stmt->architecture, s)));
}
}
return result;
}
// Create the query plan with the appropriate machine learning model
PlannedStmt *plan_create_model(CreateModelStmt *stmt, const char *query_string, ParamListInfo params,
DestReceiver *dest)
{
Query *query = (Query *)stmt->select_query;
PlannedStmt *plan = NULL;
query = setup_for_create_model(query, query_string, params);
/* plan the query */
plan = pg_plan_query(query, 0, params);
// Inject the GradientDescent node at the root of the plan
plan = add_create_model_to_plan(stmt, plan, (DestReceiverTrainModel *)dest);
return plan;
}
// Prepare the DestReceiver for training
void configure_dest_receiver_train_model(DestReceiverTrainModel *dest, AlgorithmML algorithm, const char *model_name,
const char *sql)
{
switch (algorithm) {
case LOGISTIC_REGRESSION:
case LINEAR_REGRESSION:
case SVM_CLASSIFICATION: {
dest->model = (Model *)palloc0(sizeof(ModelGradientDescent));
break;
}
case KMEANS: {
dest->model = (Model *)palloc0(sizeof(ModelKMeans));
break;
}
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Unsupported model type in model warehouse %d", algorithm)));
}
}
dest->model->algorithm = algorithm;
dest->model->model_name = pstrdup(model_name);
dest->model->sql = sql;
dest->targetlist = nullptr;
}
/*
 * exec_create_model -- execute a CREATE MODEL statement
 */
void exec_create_model(CreateModelStmt *stmt, const char *queryString, ParamListInfo params, char *completionTag)
{
#ifdef ENABLE_MULTIPLE_NODES
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("No support for distributed scenarios yet.")));
#endif
DestReceiverTrainModel *dest = NULL;
PlannedStmt *plan = NULL;
QueryDesc *queryDesc = NULL;
ScanDirection dir;
/*
* Create the tuple receiver object and insert info it will need
*/
dest = (DestReceiverTrainModel *)CreateDestReceiver(DestTrainModel);
configure_dest_receiver_train_model(dest, (AlgorithmML)stmt->algorithm, stmt->model, queryString);
plan = plan_create_model(stmt, queryString, params, (DestReceiver *)dest);
/*
* Use a snapshot with an updated command ID to ensure this query sees
* results of any previously executed queries. (This could only matter if
* the planner executed an allegedly-stable function that changed the
* database contents, but let's do it anyway to be parallel to the EXPLAIN
* code path.)
*/
PushCopiedSnapshot(GetActiveSnapshot());
UpdateActiveSnapshotCommandId();
/* Create a QueryDesc, redirecting output to our tuple receiver */
queryDesc = CreateQueryDesc(plan, queryString, GetActiveSnapshot(), InvalidSnapshot, &dest->dest, params, 0);
#ifdef ENABLE_MULTIPLE_NODES
if (ENABLE_WORKLOAD_CONTROL && (IS_PGXC_COORDINATOR)) {
#else
if (ENABLE_WORKLOAD_CONTROL) {
#endif
/* Check if need track resource */
u_sess->exec_cxt.need_track_resource = WLMNeedTrackResource(queryDesc);
}
/* call ExecutorStart to prepare the plan for execution */
ExecutorStart(queryDesc, 0);
/* workload client manager */
if (ENABLE_WORKLOAD_CONTROL) {
WLMInitQueryPlan(queryDesc);
dywlm_client_manager(queryDesc);
}
dir = ForwardScanDirection;
/* run the plan */
ExecutorRun(queryDesc, dir, 0L);
/* save the rowcount if we're given a completionTag to fill */
if (completionTag != NULL) {
errno_t rc;
rc = snprintf_s(completionTag, COMPLETION_TAG_BUFSIZE, COMPLETION_TAG_BUFSIZE - 1,
"MODEL CREATED. PROCESSED %lu", queryDesc->estate->es_processed);
securec_check_ss(rc, "\0", "\0");
}
/* and clean up */
ExecutorFinish(queryDesc);
ExecutorEnd(queryDesc);
FreeQueryDesc(queryDesc);
PopActiveSnapshot();
}
static void store_gd_expr_in_model(Datum dt, Oid type, int col, GradientDescentExprField field, Model *model,
ModelGradientDescent *model_gd, TupleDesc tupdesc)
{
switch (field) {
case GD_EXPR_ALGORITHM:
Assert(type == INT4OID);
model_gd->model.algorithm = (AlgorithmML)DatumGetInt32(dt);
break;
case GD_EXPR_OPTIMIZER:
break; // Ignore field
case GD_EXPR_RESULT_TYPE:
Assert(type == OIDOID);
model->return_type = DatumGetUInt32(dt);
break;
case GD_EXPR_NUM_ITERATIONS:
Assert(type == INT4OID);
model->num_actual_iterations = DatumGetInt32(dt);
break;
case GD_EXPR_EXEC_TIME_MSECS:
Assert(type == FLOAT4OID);
model->exec_time_secs = DatumGetFloat4(dt) / 1000.0;
break;
case GD_EXPR_PROCESSED_TUPLES:
Assert(type == INT4OID);
model->processed_tuples = DatumGetInt32(dt);
break;
case GD_EXPR_DISCARDED_TUPLES:
Assert(type == INT4OID);
model->discarded_tuples = DatumGetInt32(dt);
break;
case GD_EXPR_WEIGHTS:
Assert(type == FLOAT4ARRAYOID);
model_gd->weights = datumCopy(dt, tupdesc->attrs[col]->attbyval, tupdesc->attrs[col]->attlen);
break;
case GD_EXPR_CATEGORIES: {
ArrayType *arr = (ArrayType *)DatumGetPointer(dt);
model_gd->ncategories = ARR_DIMS(arr)[0];
model_gd->categories =
datumCopy(dt, tupdesc->attrs[col]->attbyval, tupdesc->attrs[col]->attlen);
} break;
default:
(void)type;
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Model warehouse for GradientDescent field %d not implemented", field)));
break;
}
}
static void store_tuple_gd_in_model_warehouse(TupleTableSlot *slot, DestReceiverTrainModel *dest)
{
Assert(dest->targetlist != nullptr);
TupleDesc tupdesc = slot->tts_tupleDescriptor;
Model *model = dest->model;
model->pre_time_secs = 0.0;
ModelGradientDescent *model_gd = (ModelGradientDescent *)model;
model_gd->ncategories = 0;
model_gd->model.algorithm = INVALID_ALGORITHM_ML; // undefined
int col = 0;
ListCell *lc;
foreach (lc, dest->targetlist) {
TargetEntry *target = lfirst_node(TargetEntry, lc);
GradientDescentExpr *expr = (GradientDescentExpr *)target->expr;
if (!slot->tts_isnull[col]) {
Datum dt = slot->tts_values[col];
Oid type = tupdesc->attrs[col]->atttypid;
if ((expr->field & GD_EXPR_SCORE) != 0) {
Assert(type == FLOAT4OID);
TrainingScore *score = (TrainingScore *)palloc0(sizeof(TrainingScore));
score->name = pstrdup(target->resname);
score->value = DatumGetFloat4(dt);
model->scores = lappend(model->scores, score);
} else {
store_gd_expr_in_model(dt, type, col, expr->field, model, model_gd, tupdesc);
}
}
col++;
}
}
static void store_kmeans_data_in_model(uint32_t natts, TupleDesc tupdesc, Datum *values, bool *nulls,
Model *model, ModelKMeans *model_kmeans)
{
ArrayType *centroid_ids = nullptr;
ArrayType *centroid_coordinates = nullptr;
ArrayType *objective_functions = nullptr;
ArrayType *avg_distances = nullptr;
ArrayType *min_distances = nullptr;
ArrayType *max_distances = nullptr;
ArrayType *std_dev_distances = nullptr;
ArrayType *cluster_sizes = nullptr;
/* these are the inner-facing arrays */
int32_t *centroid_ids_data = nullptr;
double *centroid_coordiates_data = nullptr;
double *objective_functions_data = nullptr;
double *avg_distances_data = nullptr;
double *min_distances_data = nullptr;
double *max_distances_data = nullptr;
double *std_dev_distances_data = nullptr;
int64_t *cluster_sizes_data = nullptr;
Oid oid = 0;
Datum attr = 0;
/*
* for tuple at a time we only use one centroid at a time
*/
for (uint32_t a = 0; a < natts; ++a) {
if (unlikely(nulls[a])) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("Unexpected null attribute %u when serializing k-means model", a)));
}
oid = tupdesc->attrs[a]->atttypid;
attr = values[a];
/*
* this switch has to match exactly the schema of the row we return (see nodeKMeans.cpp)
* there is a single row (quite big in general), thus each case of the switch executes exactly once
*/
switch (a) {
case 0:
// centroids ids of type INT4ARRAYOID
Assert(oid == INT4ARRAYOID);
centroid_ids = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
centroid_ids_data = reinterpret_cast<int32_t *>(ARR_DATA_PTR(centroid_ids));
break;
case 1:
// centroids coordinates of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
/*
* in tuple at a time, this memory reference is valid until we store the centroid
* in the model warehouse
*/
centroid_coordinates = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
centroid_coordiates_data = reinterpret_cast<double *>(ARR_DATA_PTR(centroid_coordinates));
break;
case 2:
// value of the objective functions (per cluster) of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
objective_functions = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
objective_functions_data = reinterpret_cast<double *>(ARR_DATA_PTR(objective_functions));
break;
case 3:
// avg distance of the clusters of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
avg_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
avg_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(avg_distances));
break;
case 4:
// min distance of the clusters of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
min_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
min_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(min_distances));
break;
case 5:
// max distance of the clusters of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
max_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
max_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(max_distances));
break;
case 6:
// standard deviation of clusters of type FLOAT8ARRAYOID
Assert(oid == FLOAT8ARRAYOID);
std_dev_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
std_dev_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(std_dev_distances));
break;
case 7:
// cluster sizes of type INT8ARRAYOID
Assert(oid == INT8ARRAYOID);
cluster_sizes = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
cluster_sizes_data = reinterpret_cast<int64_t *>(ARR_DATA_PTR(cluster_sizes));
break;
case 8:
// num good points of type INT8OID
Assert(oid == INT8OID);
model->processed_tuples = DatumGetInt64(attr);
break;
case 9:
// num bad point of type INT8OID
Assert(oid == INT8OID);
model->discarded_tuples = DatumGetInt64(attr);
break;
case 10:
// seedings time (secs) of type FLOAT8OID
Assert(oid == FLOAT8OID);
model->pre_time_secs = DatumGetFloat8(attr);
break;
case 11:
// execution time (secs) of type FLOAT8OID
Assert(oid == FLOAT8OID);
model->exec_time_secs = DatumGetFloat8(attr);
break;
case 12:
// actual number of iterations INT4OID
Assert(oid == INT4OID);
model->num_actual_iterations = DatumGetInt32(attr);
break;
case 13:
// actual number of centroids INT4OID
Assert(oid == INT4OID);
model_kmeans->actual_num_centroids = DatumGetInt32(attr);
break;
case 14:
// seed used for computations
Assert(oid == INT8OID);
model_kmeans->seed = DatumGetInt64(attr);
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Unknown attribute %u when serializing k-means model", a)));
}
}
uint32_t const actual_num_centroids = model_kmeans->actual_num_centroids;
uint32_t const dimension = model_kmeans->dimension;
uint32_t centroid_coordinates_offset = 0;
WHCentroid *current_centroid = nullptr;
/*
* at this point we have extracted all the attributes and the memory representation
* of the model can be constructed so that it can be stored in the model warehouse
*/
model_kmeans->centroids = reinterpret_cast<WHCentroid *>(palloc0(sizeof(WHCentroid) * actual_num_centroids));
/*
* we fill in the information of every centroid
*/
for (uint32_t current_centroid_idx = 0; current_centroid_idx < actual_num_centroids; ++current_centroid_idx) {
current_centroid = model_kmeans->centroids + current_centroid_idx;
current_centroid->id = centroid_ids_data[current_centroid_idx];
current_centroid->objective_function = objective_functions_data[current_centroid_idx];
current_centroid->avg_distance_to_centroid = avg_distances_data[current_centroid_idx];
current_centroid->min_distance_to_centroid = min_distances_data[current_centroid_idx];
current_centroid->max_distance_to_centroid = max_distances_data[current_centroid_idx];
current_centroid->std_dev_distance_to_centroid = std_dev_distances_data[current_centroid_idx];
current_centroid->cluster_size = cluster_sizes_data[current_centroid_idx];
current_centroid->coordinates = centroid_coordiates_data + centroid_coordinates_offset;
centroid_coordinates_offset += dimension;
}
}
static void store_tuple_kmeans_in_model_warehouse(TupleTableSlot *slot, DestReceiverTrainModel *dest)
{
/*
* sanity checks
*/
Assert(slot != NULL);
Assert(!slot->tts_isempty);
Assert(slot->tts_nvalid == NUM_ATTR_OUTPUT);
Assert(slot->tts_tupleDescriptor != NULL);
Assert(!TTS_HAS_PHYSICAL_TUPLE(slot));
TupleDesc tupdesc = slot->tts_tupleDescriptor;
auto model_kmeans = reinterpret_cast<ModelKMeans *>(dest->model);
Model *model = &model_kmeans->model;
if (unlikely(slot->tts_isempty))
return;
uint32_t const natts = slot->tts_nvalid;
/*
* the slot contains a virtual tuple and thus we can access its attributes directly
*/
Datum *values = slot->tts_values;
bool *nulls = slot->tts_isnull;
if (unlikely(!values && !nulls)) {
ereport(ERROR,
(errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("Empty arrays values and nulls")));
}
store_kmeans_data_in_model(natts, tupdesc, values, nulls, model, model_kmeans);
}
static void store_tuple_in_model_warehouse(TupleTableSlot *slot, DestReceiver *self)
{
DestReceiverTrainModel *dest = (DestReceiverTrainModel *)self;
Model *model = dest->model;
switch (model->algorithm) {
case LOGISTIC_REGRESSION:
case SVM_CLASSIFICATION:
case LINEAR_REGRESSION:
store_tuple_gd_in_model_warehouse(slot, dest);
break;
case KMEANS:
store_tuple_kmeans_in_model_warehouse(slot, dest);
break;
case INVALID_ALGORITHM_ML:
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Unsupported model type %d", static_cast<int>(model->algorithm))));
break;
}
store_model(model);
}
static void do_nothing_startup(DestReceiver *self, int operation, TupleDesc typehyperp)
{
/* do nothing */
}
static void do_nothing_cleanup(DestReceiver *self)
{
/* this is used for both shutdown and destroy methods */
}
DestReceiver *CreateTrainModelDestReceiver()
{
DestReceiverTrainModel *dr = (DestReceiverTrainModel *)palloc0(sizeof(DestReceiverTrainModel));
DestReceiver *result = &dr->dest;
result->rStartup = do_nothing_startup;
result->receiveSlot = store_tuple_in_model_warehouse;
result->rShutdown = do_nothing_cleanup;
result->rDestroy = do_nothing_cleanup;
result->mydest = DestTrainModel;
return result;
}


@ -0,0 +1,153 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/commands/predict_by.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/predict_by.h"
#include "catalog/pg_type.h"
#include "db4ai/model_warehouse.h"
#include "nodes/parsenodes_common.h"
#include "parser/parse_expr.h"
#include "utils/array.h"
#include "utils/builtins.h"
#include "db4ai/gd.h"
#define DEBUG_MODEL_RETURN_TYPE INT4OID // Manually set return type of the model, until it is available from the catalog
/*
* functions relevant to k-means and defined in kmeans.cpp
*/
ModelPredictor kmeans_predict_prepare(Model const * model);
Datum kmeans_predict(ModelPredictor model, Datum *data, bool *nulls, Oid *types, int nargs);
struct PredictionByData {
PredictorInterface *api;
ModelPredictor model_predictor;
};
static PredictionByData *initialize_predict_by_data(Model *model)
{
PredictionByData *result = (PredictionByData *)palloc0(sizeof(PredictionByData));
// Initialize the API handlers
switch (model->algorithm) {
case LOGISTIC_REGRESSION:
case SVM_CLASSIFICATION:
case LINEAR_REGRESSION: {
result->api = (PredictorInterface *)palloc0(sizeof(PredictorInterface));
result->api->predict = gd_predict;
result->api->prepare = gd_predict_prepare;
break;
}
case KMEANS: {
result->api = (PredictorInterface *)palloc0(sizeof(PredictorInterface));
result->api->predict = kmeans_predict;
result->api->prepare = kmeans_predict_prepare;
break;
}
case INVALID_ALGORITHM_ML:
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Model type %d is not supported for prediction", (int)model->algorithm)));
break;
}
}
// Prepare the in memory version of the model for efficient prediction
result->model_predictor = result->api->prepare(model);
return result;
}
Datum db4ai_predict_by(PG_FUNCTION_ARGS)
{
// First argument is the model, the following ones are the inputs to the model predictor
if (PG_NARGS() < 2) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("INVALID Number of parameters %d", PG_NARGS())));
}
int var_args_size = PG_NARGS() - 1;
Datum *var_args = &fcinfo->arg[1]; // We skip the model name
bool *nulls = &fcinfo->argnull[1];
Oid *types = &fcinfo->argTypes[1];
PredictionByData *prediction_by_data;
if (fcinfo->flinfo->fn_extra != NULL) {
prediction_by_data = (PredictionByData *)fcinfo->flinfo->fn_extra;
} else {
// Initialize the model, and save the object as state of the function.
// The memory allocation is done in the function context. So, the model
// will be deallocated automatically when the function finishes
MemoryContext oldContext = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
text *model_name_text = PG_GETARG_TEXT_P(0);
char *model_name = text_to_cstring(model_name_text);
Model *model = get_model(model_name, false);
prediction_by_data = initialize_predict_by_data(model);
fcinfo->flinfo->fn_extra = prediction_by_data;
pfree(model_name);
MemoryContextSwitchTo(oldContext);
}
Datum result =
prediction_by_data->api->predict(prediction_by_data->model_predictor, var_args, nulls, types, var_args_size);
PG_RETURN_DATUM(result);
}
// These functions are thin wrappers around the generic function. We need the wrappers
// to be compliant with the openGauss type system, which specifies the return type
Datum db4ai_predict_by_bool(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_int32(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_int64(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_float4(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_float8(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_numeric(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}
Datum db4ai_predict_by_text(PG_FUNCTION_ARGS)
{
return db4ai_predict_by(fcinfo);
}


@ -0,0 +1,32 @@
# executor.cmake
set(TGT_executor_SRC
${CMAKE_CURRENT_SOURCE_DIR}/distance_functions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fp_ops.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hyperparameter_validation.cpp
)
set(TGT_executor_INC
${PROJECT_OPENGS_DIR}/contrib/log_fdw
${PROJECT_TRUNK_DIR}/distribute/bin/gds
${PROJECT_SRC_DIR}/include/libcomm
${PROJECT_SRC_DIR}/include
${PROJECT_SRC_DIR}/lib/gstrace
${LZ4_INCLUDE_PATH}
${LIBCGROUP_INCLUDE_PATH}
${LIBORC_INCLUDE_PATH}
${EVENT_INCLUDE_PATH}
${PROTOBUF_INCLUDE_PATH}
${ZLIB_INCLUDE_PATH}
)
set(executor_DEF_OPTIONS ${MACRO_OPTIONS})
set(executor_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
list(REMOVE_ITEM executor_COMPILE_OPTIONS -fPIC)
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
set(executor_COMPILE_OPTIONS ${executor_COMPILE_OPTIONS} -std=c++14 -fPIE -mavx)
else()
set(executor_COMPILE_OPTIONS ${executor_COMPILE_OPTIONS} -std=c++14 -fPIE)
endif()
set(executor_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor TGT_executor_SRC TGT_executor_INC "${executor_DEF_OPTIONS}" "${executor_COMPILE_OPTIONS}" "${executor_LINK_OPTIONS}")
add_subdirectory(gd)
add_subdirectory(kmeans)


@ -0,0 +1,33 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/executor/Makefile
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai/executor
top_builddir = ../../../../..
include $(top_builddir)/src/Makefile.global
PLATFORM_ARCH = $(shell uname -p)
ifeq ($(PLATFORM_ARCH),x86_64)
override CPPFLAGS += -mavx
endif
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
SUBDIRS = gd kmeans
OBJS = fp_ops.o distance_functions.o hyperparameter_validation.o
include $(top_srcdir)/src/gausskernel/common.mk


@ -0,0 +1,505 @@
/* *
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
distance_functions.cpp
Current set of distance functions that can be used (for k-means for example)
IDENTIFICATION
src/gausskernel/dbmind/db4ai/executor/distance_functions.cpp
---------------------------------------------------------------------------------------
* */
#include "postgres.h"
#include "db4ai/distance_functions.h"
#include "db4ai/fp_ops.h"
#include "db4ai/db4ai_cpu.h"
#include <cmath>
#if defined(__x86_64__) && defined(__SSE3__)
#include <pmmintrin.h>
#elif defined(__aarch64__) && defined(__ARM_NEON)
#include <arm_neon.h>
#endif
/*
* L1 distance (Manhattan)
* We sum using cascaded summation
* This version is unvectorized and is used in case vectorized instructions
* are not available or for the case that the dimension is not a multiple
* of the width of the registers
*/
static force_inline double l1_non_vectorized(double const * p, double const * q, uint32_t const dimension)
{
double term = 0.;
double term_correction = 0.;
double distance = 0.;
double distance_correction = 0.;
twoDiff(q[0], p[0], &term, &term_correction);
term += term_correction;
// absolute value of the difference (hopefully done by clearing the sign bit)
distance = std::abs(term);
for (uint32_t d = 1; d < dimension; ++d) {
twoDiff(q[d], p[d], &term, &term_correction);
term += term_correction;
term = std::abs(term);
twoSum(distance, term, &distance, &term_correction);
distance_correction += term_correction;
}
return distance + distance_correction;
}
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
/*
* L1 distance (Manhattan)
* This version is vectorized using SSE or NEON and is used in case only 128-bit
* vectorized instructions are available
*/
static double l1_128(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L1 distance (128-bit): dimension must be larger than 0")));
double distance[2] = {0.};
/* the result of dimension modulo 2 */
uint32_t const dimension_remainder = dimension & 0x1U;
uint32_t const offset = 2U;
double distance_first_terms = 0.;
double local_distance_correction = 0.;
double global_distance_correction = 0.;
/*
* this will compute the very first terms of the distance
* (the ones that cannot be fully computed using simd registers, and
* thus have to be done with scalar computations)
*/
if (dimension_remainder > 0) {
/*
* this is gonna compute the 1-dimensional distance (the very first
* term in the whole distance computation)
*/
distance_first_terms = l1_non_vectorized(p, q, dimension_remainder);
}
/*
* if the dimension is < 2 then the term above is the whole distance
* otherwise we have at least one simd register we can fill
*/
if (unlikely(dimension < offset))
return distance_first_terms;
#if defined(__x86_64__)
__m128d const zero = _mm_setzero_pd();
__m128d const sign_mask = _mm_set1_pd(-0.0f);
__m128d sum = zero;
__m128d absolute_value;
__m128d sub;
__m128d p128, q128;
#else // aarch64
float64x2_t const zero = vdupq_n_f64(0);
float64x2_t sum = zero;
float64x2_t absolute_value;
float64x2_t sub;
float64x2_t p128, q128;
#endif
Assert(((dimension - dimension_remainder) & 0x1U) == 0U);
for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
p128 = _mm_loadu_pd(p + d);
q128 = _mm_loadu_pd(q + d);
sub = _mm_sub_pd(p128, q128);
/* this clears out the sign bit of sub (thus computing its absolute value) */
absolute_value = _mm_andnot_pd(sign_mask, sub);
sum = _mm_add_pd(sum, absolute_value);
#else // aarch64
p128 = vld1q_f64(p + d);
q128 = vld1q_f64(q + d);
sub = vsubq_f64(p128, q128);
/* this computes the absolute value of sub */
absolute_value = vabsq_f64(sub);
sum = vaddq_f64(sum, absolute_value);
#endif
}
/*
* in here we end up having a register with two terms that need to be added up to produce
* the final answer, we first perform a horizontal add to reduce the two terms to one
*/
#if defined(__x86_64__)
sum = _mm_hadd_pd(sum, zero);
#else // aarch64
sum = vpaddq_f64(sum, zero);
#endif
/*
* we extract the remaining term [x,0] to produce the final solution
*/
#if defined(__x86_64__)
_mm_storeu_pd(distance, sum);
#else // aarch64
vst1q_f64(distance, sum);
#endif
if (dimension_remainder > 0) {
/* d[0] = d[0] + distance_first_terms */
twoSum(distance[0], distance_first_terms, distance, &local_distance_correction);
global_distance_correction += local_distance_correction;
}
return distance[0] + global_distance_correction;
}
#endif
/*
* Squared Euclidean (default)
* We sum using cascaded summation
* This version is unvectorized and is used in case vectorized instructions
* are not available or for the case that the dimension is not a multiple
* of the width of the registers
*/
static force_inline double l2_squared_non_vectorized(double const * p, double const * q, uint32_t const dimension)
{
double subtraction = 0.;
double subtraction_correction = 0.;
double term = 0.;
double term_correction = 0.;
double distance = 0.;
double distance_correction = 0.;
twoDiff(q[0], p[0], &subtraction, &subtraction_correction);
subtraction += subtraction_correction;
square(subtraction, &term, &term_correction);
term += term_correction;
distance = term;
for (uint32_t d = 1; d < dimension; ++d) {
twoDiff(q[d], p[d], &subtraction, &subtraction_correction);
subtraction += subtraction_correction;
square(subtraction, &term, &term_correction);
term += term_correction;
twoSum(distance, term, &distance, &term_correction);
distance_correction += term_correction;
}
return distance + distance_correction;
}
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
/*
* Squared Euclidean (default)
* This version is vectorized using SSE or NEON and is used in case only 128-bit
* vectorized instructions are available
*/
static double l2_squared_128(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L2 squared distance (128-bit): dimension must be larger than 0")));
double distance[2] = {0.};
/* the result of dimension modulo 2 */
uint32_t const dimension_remainder = dimension & 0x1U;
uint32_t const offset = 2U;
double distance_first_terms = 0.;
double local_distance_correction = 0.;
double global_distance_correction = 0.;
/*
* this will compute the very first terms of the distance
* (the ones that cannot be fully computed using simd registers, and
* thus have to be done with scalar computations)
*/
if (dimension_remainder > 0) {
/*
* this is gonna compute the 1-dimensional distance (the very first
* term in the whole distance computation)
*/
distance_first_terms = l2_squared_non_vectorized(p, q, dimension_remainder);
}
/*
* if the dimension is < 2 then the term above is the whole distance
* otherwise we have at least one simd register we can fill
*/
if (unlikely(dimension < offset))
return distance_first_terms;
#if defined(__x86_64__)
__m128d const zero = _mm_setzero_pd();
__m128d sum = zero;
__m128d square;
__m128d sub;
__m128d p128, q128;
#else // aarch64
float64x2_t const zero = vdupq_n_f64(0);
float64x2_t sum = zero;
float64x2_t square;
float64x2_t sub;
float64x2_t p128, q128;
#endif
Assert(((dimension - dimension_remainder) & 0x1U) == 0U);
for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
p128 = _mm_loadu_pd(p + d);
q128 = _mm_loadu_pd(q + d);
sub = _mm_sub_pd(p128, q128);
square = _mm_mul_pd(sub, sub);
sum = _mm_add_pd(sum, square);
#else // aarch64
p128 = vld1q_f64(p + d);
q128 = vld1q_f64(q + d);
sub = vsubq_f64(p128, q128);
square = vmulq_f64(sub, sub);
sum = vaddq_f64(sum, square);
#endif
}
/*
* in here we end up having a register with two terms that need to be added up to produce
* the final answer, we first perform a horizontal add to reduce the two terms to one
*/
#if defined(__x86_64__)
sum = _mm_hadd_pd(sum, zero);
#else // aarch64
sum = vpaddq_f64(sum, zero);
#endif
/*
* we extract the remaining term [x,0] to produce the final solution
*/
#if defined(__x86_64__)
_mm_storeu_pd(distance, sum);
#else // aarch64
vst1q_f64(distance, sum);
#endif
if (dimension_remainder > 0) {
/* d[0] = d[0] + distance_first_terms */
twoSum(distance[0], distance_first_terms, distance, &local_distance_correction);
global_distance_correction += local_distance_correction;
}
return distance[0] + global_distance_correction;
}
#endif
/*
* L infinity distance (Chebyshev)
 * This version is not vectorized and is used when vectorized instructions
 * are not available, or when the dimension is not a multiple
 * of the width of the registers
*/
static force_inline double linf_non_vectorized(double const * p, double const * q, uint32_t const dimension)
{
double distance = 0.;
double term = 0.;
double term_correction = 0.;
twoDiff(q[0], p[0], &term, &term_correction);
term += term_correction;
// absolute value of the difference (typically implemented by clearing the sign bit)
distance = std::abs(term);
for (uint32_t d = 1; d < dimension; ++d) {
twoDiff(q[d], p[d], &term, &term_correction);
term += term_correction;
term = std::abs(term);
distance = term > distance ? term : distance;
}
return distance;
}
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
/*
* L infinity distance (Chebyshev)
* This version is vectorized using SSE or NEON and is used in case only 128-bit
* vectorized instructions are available
*/
static double linf_128(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L infinity distance (128-bit): dimension must be larger than 0")));
double distance[2] = {0.};
/* the result of dimension modulo 2 */
uint32_t const dimension_remainder = dimension & 0x1U;
uint32_t const offset = 2U;
double distance_first_terms = 0.;
/*
* this will compute the very first terms of the distance
* (the ones that cannot be fully computed using simd registers)
*/
if (dimension_remainder > 0)
distance_first_terms = linf_non_vectorized(p, q, dimension_remainder);
/*
 * if the dimension is < 2 then the term above is the whole distance
* otherwise we have at least one simd register we can fill
*/
if (unlikely(dimension < offset))
return distance_first_terms;
#if defined(__x86_64__)
__m128d const zero = _mm_setzero_pd();
__m128d const sign_mask = _mm_set1_pd(-0.0);
__m128d max = zero;
__m128d absolute_value;
__m128d sub;
__m128d p128, q128;
#else // aarch64
float64x2_t const zero = vdupq_n_f64(0);
float64x2_t max = zero;
float64x2_t absolute_value;
float64x2_t sub;
float64x2_t p128, q128;
#endif
Assert(((dimension - dimension_remainder) & 0x1U) == 0U);
for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
p128 = _mm_loadu_pd(p + d);
q128 = _mm_loadu_pd(q + d);
sub = _mm_sub_pd(p128, q128);
/* clear the sign bit of sub (thus computing its absolute value) */
absolute_value = _mm_andnot_pd(sign_mask, sub);
max = _mm_max_pd(max, absolute_value);
#else // aarch64
p128 = vld1q_f64(p + d);
q128 = vld1q_f64(q + d);
sub = vsubq_f64(p128, q128);
/* vabsq_f64 computes the absolute value of sub directly */
absolute_value = vabsq_f64(sub);
max = vmaxq_f64(max, absolute_value);
#endif
}
/*
 * at this point we have a register with two terms, from which we extract the max
* to produce the final answer
*/
#if defined(__x86_64__)
_mm_storeu_pd(distance, max);
#else // aarch64
vst1q_f64(distance, max);
#endif
double result = distance_first_terms;
for (uint32_t m = 0; m < offset; ++m)
result = distance[m] > result ? distance[m] : result;
return result;
}
#endif
/*
* L1 distance (Manhattan)
* This is the main function. It will be automatically vectorized
* if possible
*/
double l1(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L1 distance: dimension must be larger than 0")));
/*
 * depending on the features of the underlying processor we vectorize one way
 * or another; in the worst case we do not vectorize at all
*/
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
return l1_128(p, q, dimension);
#else
return l1_non_vectorized(p, q, dimension);
#endif
}
/*
* Squared Euclidean (default)
* This is the main function. It will be automatically vectorized
* if possible
*/
double l2_squared(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L2 squared distance: dimension must be larger than 0")));
/*
 * depending on the features of the underlying processor we vectorize one way
 * or another; in the worst case we do not vectorize at all
*/
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
return l2_squared_128(p, q, dimension);
#else
return l2_squared_non_vectorized(p, q, dimension);
#endif
}
/*
* L2 distance (Euclidean)
* This is the main function. It will be automatically vectorized
* if possible
*/
double l2(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L2 distance: dimension must be larger than 0")));
/*
* this one is vectorized automatically (or not)
*/
double const l2_sq = l2_squared(p, q, dimension);
// we can replace this with exact computation via mpfr (more costly, but the best alternative
// for fixed precision)
return std::sqrt(l2_sq);
}
/*
* L infinity distance (Chebyshev)
* This is the main function. It will be automatically vectorized
* if possible
*/
double linf(double const * p, double const * q, uint32_t const dimension)
{
if (unlikely(dimension == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("L infinity distance: dimension must be larger than 0")));
/*
 * depending on the features of the underlying processor we vectorize one way
 * or another; in the worst case we do not vectorize at all
*/
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
return linf_128(p, q, dimension);
#else
return linf_non_vectorized(p, q, dimension);
#endif
}

@@ -0,0 +1,256 @@
/* *
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
fp_ops.cpp
Robust floating point operations
IDENTIFICATION
src/gausskernel/dbmind/db4ai/executor/fp_ops.cpp
---------------------------------------------------------------------------------------
* */
#include "nodes/execnodes.h"
#include "db4ai/fp_ops.h"
#include "db4ai/db4ai_cpu.h"
#define WITH_ROBUST_OPS
/*
* Knuth's high precision sum a + b (Shewchuk's version)
*/
void twoSum(double const a, double const b, double *sum, double *e)
{
*sum = a + b;
#ifdef WITH_ROBUST_OPS
double const bv = *sum - a;
double const av = *sum - bv;
double const be = b - bv;
double const ae = a - av;
*e = ae + be;
#else
*e = 0.;
#endif
}
/*
 * High precision subtraction a - b (Shewchuk's version)
*/
void twoDiff(double const a, double const b, double *sub, double *e)
{
*sub = a - b;
#ifdef WITH_ROBUST_OPS
double const bv = a - *sub;
double const av = *sub + bv;
double const be = bv - b;
double const ae = a - av;
*e = ae + be;
#else
*e = 0.;
#endif
}
static force_inline void veltkamp_split(double p, double *p_hi, double *p_low)
{
uint32_t const shift = 27U; // ceil(53 / 2)
double const c = static_cast<double>((1U << shift) + 1U) * p;
double p_big = c - p;
*p_hi = c - p_big;
*p_low = p - *p_hi;
}
/*
* High precision product a * b (Shewchuk's version of Dekker-Veltkamp)
*/
void twoMult(double const a, double const b, double *mult, double *e)
{
*mult = a * b;
#ifdef WITH_ROBUST_OPS
double a_hi, a_low;
double b_hi, b_low;
double e_1, e_2, e_3;
veltkamp_split(a, &a_hi, &a_low);
veltkamp_split(b, &b_hi, &b_low);
e_1 = *mult - (a_hi * b_hi);
e_2 = e_1 - (a_low * b_hi);
e_3 = e_2 - (a_hi * b_low);
*e = (a_low * b_low) - e_3;
#else
*e = 0.;
#endif
}
/*
* High precision square a * a (Shewchuk's version of Dekker-Veltkamp)
*/
void square(double a, double *square, double *e)
{
*square = a * a;
#ifdef WITH_ROBUST_OPS
double a_hi, a_low;
double e_1, e_3;
veltkamp_split(a, &a_hi, &a_low);
e_1 = *square - (a_hi * a_hi);
e_3 = e_1 - (a_low * (a_hi + a_hi));
*e = (a_low * a_low) - e_3;
#else
*e = 0.;
#endif
}
/*
 * High precision division a / b = *div + *e, computed as twoMult(a, 1 / b);
 * the reciprocal is itself rounded, so the result is accurate but not exact
 */
void twoDiv(double const a, double const b, double *div, double *e)
{
#ifdef WITH_ROBUST_OPS
twoMult(a, 1 / b, div, e);
#else
*div = a / b;
*e = 0.;
#endif
}
/*
* Implementing the methods of the incremental statistics
*/
IncrementalStatistics IncrementalStatistics::operator + (IncrementalStatistics const & rhs) const
{
IncrementalStatistics sum = *this;
sum += rhs;
return sum;
}
IncrementalStatistics IncrementalStatistics::operator - (IncrementalStatistics const & rhs) const
{
IncrementalStatistics minus = *this;
minus -= rhs;
return minus;
}
IncrementalStatistics &IncrementalStatistics::operator += (IncrementalStatistics const & rhs)
{
// for the term corresponding to the variance, the terms add up but there is a correcting term to
// be added up as well. The proof of this can be found in the updated version of the HLD
double const total_r = rhs.total;
uint64_t const rhs_population = rhs.population;
double increment_s = 0.;
bool execute = (rhs_population > 0) && (population > 0);
// this branch predicts well because population is mostly > 0
if (likely(execute)) {
increment_s = ((total_r * total_r) / rhs_population) + ((total * total) / population) -
(((total_r + total) * (total_r + total)) / (rhs_population + population));
}
total += rhs.total;
population += rhs.population;
s += (rhs.s + increment_s);
max_value = rhs.max_value > max_value ? rhs.max_value : max_value;
min_value = rhs.min_value < min_value ? rhs.min_value : min_value;
return *this;
}
IncrementalStatistics &IncrementalStatistics::operator -= (IncrementalStatistics const & rhs)
{
uint64_t const current_population = population - rhs.population;
double const current_total = total - rhs.total;
double decrement_s = 0.;
bool const execute = (population > 0) && (current_population > 0) && (rhs.population > 0);
// this branch predicts well because population is mostly > 0
if (likely(execute)) {
decrement_s = ((current_total * current_total) / current_population) +
((rhs.total * rhs.total) / rhs.population) - ((total * total) / population);
}
total = current_total;
population = current_population;
s -= (rhs.s + decrement_s);
s = std::abs(s);
return *this;
}
double IncrementalStatistics::getMin() const
{
return min_value;
}
void IncrementalStatistics::setMin(double min)
{
min_value = min;
}
double IncrementalStatistics::getMax() const
{
return max_value;
}
void IncrementalStatistics::setMax(double max)
{
max_value = max;
}
double IncrementalStatistics::getTotal() const
{
return total;
}
void IncrementalStatistics::setTotal(double t)
{
total = t;
}
uint64_t IncrementalStatistics::getPopulation() const
{
return population;
}
void IncrementalStatistics::setPopulation(uint64_t pop)
{
population = pop;
}
double IncrementalStatistics::getEmpiricalMean() const
{
double mean = 0.;
if (likely(population > 0))
mean = total / population;
return mean;
}
double IncrementalStatistics::getEmpiricalVariance() const
{
double variance = 0.;
if (likely(population > 0))
variance = s / population;
return variance;
}
double IncrementalStatistics::getEmpiricalStdDev() const
{
return std::sqrt(getEmpiricalVariance());
}
bool IncrementalStatistics::reset()
{
memset_s(this, sizeof(IncrementalStatistics), 0, sizeof(IncrementalStatistics));
min_value = DBL_MAX;
return true;
}

@@ -0,0 +1,31 @@
# gd.cmake
set(TGT_gd_SRC
${CMAKE_CURRENT_SOURCE_DIR}/gd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linregr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/logregr.cpp
${CMAKE_CURRENT_SOURCE_DIR}/matrix.cpp
${CMAKE_CURRENT_SOURCE_DIR}/optimizer_gd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/optimizer_ngd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/predict.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shuffle_cache.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svm.cpp
)
set(TGT_gd_INC
${PROJECT_OPENGS_DIR}/contrib/log_fdw
${PROJECT_TRUNK_DIR}/distribute/bin/gds
${PROJECT_SRC_DIR}/include/libcomm
${PROJECT_SRC_DIR}/include
${PROJECT_SRC_DIR}/lib/gstrace
${LZ4_INCLUDE_PATH}
${LIBCGROUP_INCLUDE_PATH}
${LIBORC_INCLUDE_PATH}
${EVENT_INCLUDE_PATH}
${PROTOBUF_INCLUDE_PATH}
${ZLIB_INCLUDE_PATH}
)
set(gd_DEF_OPTIONS ${MACRO_OPTIONS})
set(gd_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS} -std=c++14 -fPIE)
list(REMOVE_ITEM gd_COMPILE_OPTIONS -fPIC)
set(gd_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor_gd TGT_gd_SRC TGT_gd_INC "${gd_DEF_OPTIONS}" "${gd_COMPILE_OPTIONS}" "${gd_LINK_OPTIONS}")

@@ -0,0 +1,22 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/executor/gd/Makefile
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai/executor/gd
top_builddir = ../../../../../..
include $(top_builddir)/src/Makefile.global
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
OBJS = gd.o matrix.o optimizer_gd.o optimizer_ngd.o shuffle_cache.o logregr.o svm.o linregr.o predict.o
include $(top_srcdir)/src/gausskernel/common.mk

@@ -0,0 +1,357 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* gd.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/gd.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "postgres.h"
#include "utils/builtins.h"
#include "nodes/makefuncs.h"
#include "db4ai/gd.h"
#include "nodes/primnodes.h"
#include "utils/array.h"
const char *gd_get_optimizer_name(OptimizerML optimizer)
{
static const char* names[] = { "gd", "ngd" };
if (optimizer > OPTIMIZER_NGD)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid optimizer %d", optimizer)));
return names[optimizer];
}
const char *gd_get_expr_name(GradientDescentExprField field)
{
const char* names[] = {
"algorithm",
"optimizer",
"result_type",
"num_iterations",
"exec_time_msecs",
"processed",
"discarded",
"weights",
"categories",
};
if ((field & GD_EXPR_SCORE) != 0)
return gd_get_metric_name(field & ~GD_EXPR_SCORE);
if (field > GD_EXPR_CATEGORIES)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid GD expression field %d", field)));
return names[field];
}
Datum gd_float_get_datum(Oid type, gd_float value)
{
Datum datum = 0;
switch (type) {
case BOOLOID:
datum = BoolGetDatum(value != 0.0);
break;
case INT1OID:
datum = Int8GetDatum(value);
break;
case INT2OID:
datum = Int16GetDatum(value);
break;
case INT4OID:
datum = Int32GetDatum(value);
break;
case INT8OID:
datum = Int64GetDatum(value);
break;
case FLOAT4OID:
datum = Float4GetDatum(value);
break;
case FLOAT8OID:
datum = Float8GetDatum(value);
break;
case NUMERICOID:
datum = DirectFunctionCall1(float4_numeric, Float4GetDatum(value));
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Oid type %d not yet supported", type)));
break;
}
return datum;
}
gd_float gd_datum_get_float(Oid type, Datum datum)
{
gd_float value = 0;
switch (type) {
case BOOLOID:
value = DatumGetBool(datum) ? 1.0 : 0.0;
break;
case INT1OID:
value = DatumGetInt8(datum);
break;
case INT2OID:
value = DatumGetInt16(datum);
break;
case INT4OID:
value = DatumGetInt32(datum);
break;
case INT8OID:
value = DatumGetInt64(datum);
break;
case FLOAT4OID:
value = DatumGetFloat4(datum);
break;
case FLOAT8OID:
value = DatumGetFloat8(datum);
break;
case NUMERICOID:
value = DatumGetFloat8(DirectFunctionCall1(numeric_float8_no_overflow, datum));
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Oid type %d not yet supported", type)));
break;
}
return value;
}
char *gd_get_metric_name(int metric)
{
switch (metric) {
case METRIC_ACCURACY:
return "accuracy";
case METRIC_F1:
return "f1";
case METRIC_PRECISION:
return "precision";
case METRIC_RECALL:
return "recall";
case METRIC_LOSS:
return "loss";
case METRIC_MSE:
return "mse";
default:
Assert(false);
}
return nullptr;
}
extern GradientDescentAlgorithm gd_logistic_regression;
extern GradientDescentAlgorithm gd_svm_classification;
extern GradientDescentAlgorithm gd_linear_regression;
GradientDescentAlgorithm *gd_get_algorithm(AlgorithmML algorithm)
{
GradientDescentAlgorithm *gd_algorithm = nullptr;
switch (algorithm) {
case LOGISTIC_REGRESSION:
gd_algorithm = &gd_logistic_regression;
break;
case SVM_CLASSIFICATION:
gd_algorithm = &gd_svm_classification;
break;
case LINEAR_REGRESSION:
gd_algorithm = &gd_linear_regression;
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid algorithm %d", algorithm)));
break;
}
return gd_algorithm;
}
// ////////////////////////////////////////////////////////////////////////////
// expressions for projections
static struct {
GradientDescentExprField field;
Oid fieldtype;
char *name;
} GradientDescentExpr_fields[] = {
{ GD_EXPR_ALGORITHM, INT4OID, "algorithm"},
{ GD_EXPR_OPTIMIZER, INT4OID, "optimizer"},
{ GD_EXPR_RESULT_TYPE, OIDOID, "result_type"},
{ GD_EXPR_NUM_ITERATIONS, INT4OID, "num_iterations"},
{ GD_EXPR_EXEC_TIME_MSECS, FLOAT4OID, "exec_time_msecs"},
{ GD_EXPR_PROCESSED_TUPLES, INT4OID, "processed_tuples"},
{ GD_EXPR_DISCARDED_TUPLES, INT4OID, "discarded_tuples"},
{ GD_EXPR_WEIGHTS, FLOAT4ARRAYOID, "weights"},
};
static GradientDescentExpr *makeGradientDescentExpr(GradientDescentExprField field, Oid fieldtype)
{
GradientDescentExpr *xpr = makeNode(GradientDescentExpr);
xpr->field = field;
xpr->fieldtype = fieldtype;
return xpr;
}
List *makeGradientDescentExpr(AlgorithmML algorithm, List *list, int field)
{
Expr *expr;
for (size_t i = 0; i < sizeof(GradientDescentExpr_fields) / sizeof(GradientDescentExpr_fields[0]); i++) {
expr = (Expr *)makeGradientDescentExpr(GradientDescentExpr_fields[i].field,
GradientDescentExpr_fields[i].fieldtype);
list = lappend(list, makeTargetEntry(expr, field++, GradientDescentExpr_fields[i].name, false));
}
// add metrics
GradientDescentAlgorithm *palgo = gd_get_algorithm(algorithm);
int metrics = palgo->metrics;
int metric = 1;
while (metrics != 0) {
if (metrics & metric) {
expr = (Expr *)makeGradientDescentExpr(makeGradientDescentExprFieldScore(metric), FLOAT4OID);
list = lappend(list, makeTargetEntry(expr, field++, gd_get_metric_name(metric), false));
metrics &= ~metric;
}
metric <<= 1;
}
// binary value mappings
if (dep_var_is_binary(palgo)) {
expr = (Expr *)makeGradientDescentExpr(GD_EXPR_CATEGORIES, TEXTARRAYOID);
list = lappend(list, makeTargetEntry(expr, field++, "categories", false));
}
return list;
}
Datum ExecGDExprScore(GradientDescentExprState *mlstate, bool *isNull)
{
Datum dt = 0;
bool hasp, hasr;
double precision, recall;
GradientDescentState *gd_state = (GradientDescentState *)mlstate->ps;
switch (mlstate->xpr->field & ~GD_EXPR_SCORE) {
case METRIC_LOSS:
dt = Float4GetDatum(gd_state->loss);
break;
case METRIC_ACCURACY:
dt = Float4GetDatum(get_accuracy(&gd_state->scores));
break;
case METRIC_F1: // 2 * (precision * recall) / (precision + recall)
precision = get_precision(&gd_state->scores, &hasp);
recall = get_recall(&gd_state->scores, &hasr);
if ((hasp && precision > 0) || (hasr && recall > 0)) {
dt = Float4GetDatum(2.0 * precision * recall / (precision + recall));
} else
*isNull = true;
break;
case METRIC_PRECISION:
precision = get_precision(&gd_state->scores, &hasp);
if (hasp) {
dt = Float4GetDatum(precision);
} else
*isNull = true;
break;
case METRIC_RECALL:
recall = get_recall(&gd_state->scores, &hasr);
if (hasr) {
dt = Float4GetDatum(recall);
} else
*isNull = true;
break;
case METRIC_MSE:
dt = Float4GetDatum(gd_state->scores.mse);
break;
default:
*isNull = true;
break;
}
return dt;
}
Datum ExecNonGDExprScore(GradientDescentExprState *mlstate, ExprContext *econtext, bool *isNull)
{
Datum dt = 0;
Oid typoutput;
GradientDescentState *gd_state = (GradientDescentState *)mlstate->ps;
GradientDescent *gd_node = (GradientDescent *)gd_state->ss.ps.plan;
ArrayBuildState *astate = NULL;
switch (mlstate->xpr->field) {
case GD_EXPR_ALGORITHM:
dt = Int32GetDatum(gd_node->algorithm);
break;
case GD_EXPR_OPTIMIZER:
dt = Int32GetDatum(gd_node->optimizer);
break;
case GD_EXPR_RESULT_TYPE:
typoutput = get_atttypid(gd_state, gd_node->targetcol);
dt = UInt32GetDatum(typoutput);
break;
case GD_EXPR_NUM_ITERATIONS:
dt = Int32GetDatum(gd_state->n_iterations);
break;
case GD_EXPR_EXEC_TIME_MSECS:
dt = Float4GetDatum(gd_state->usecs / 1000.0);
break;
case GD_EXPR_PROCESSED_TUPLES:
dt = Int32GetDatum(gd_state->processed);
break;
case GD_EXPR_DISCARDED_TUPLES:
dt = Int32GetDatum(gd_state->discarded);
break;
case GD_EXPR_WEIGHTS: {
gd_float *pw = gd_state->weights.data;
for (int i = 0; i < gd_state->weights.rows; i++)
astate = accumArrayResult(astate, Float4GetDatum(*pw++), false, FLOAT4OID, CurrentMemoryContext);
dt = makeArrayResult(astate, econtext->ecxt_per_query_memory);
} break;
case GD_EXPR_CATEGORIES:
typoutput = get_atttypid(gd_state, gd_node->targetcol);
for (int i = 0; i < gd_state->num_classes; i++)
astate =
accumArrayResult(astate, gd_state->binary_classes[i], false, typoutput, CurrentMemoryContext);
dt = makeArrayResult(astate, econtext->ecxt_per_query_memory);
break;
default:
*isNull = true;
break;
}
return dt;
}
Datum ExecEvalGradientDescent(GradientDescentExprState *mlstate, ExprContext *econtext, bool *isNull,
ExprDoneCond *isDone)
{
Datum dt = 0;
if (isDone != NULL)
*isDone = ExprSingleResult;
*isNull = false;
if ((mlstate->xpr->field & GD_EXPR_SCORE) != 0) {
dt = ExecGDExprScore(mlstate, isNull);
} else {
dt = ExecNonGDExprScore(mlstate, econtext, isNull);
}
return dt;
}

@@ -0,0 +1,80 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
 * linregr.cpp
*
* IDENTIFICATION
 * src/gausskernel/dbmind/db4ai/executor/gd/linregr.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
static void linear_reg_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
Matrix *weights, Matrix *gradients)
{
Assert(features->rows > 0);
// xT * (x * w - y)
Matrix loss;
matrix_init(&loss, features->rows);
matrix_mult_vector(features, weights, &loss);
matrix_subtract(&loss, dep_var);
Matrix x_t;
matrix_init_transpose(&x_t, features);
matrix_mult_vector(&x_t, &loss, gradients);
matrix_release(&x_t);
matrix_release(&loss);
}
static double linear_reg_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
const Matrix *weights, Scores *scores)
{
Assert(features->rows > 0);
// loss = sum((x * w - y)^2) / 2m
Matrix errors;
matrix_init(&errors, features->rows);
matrix_mult_vector(features, weights, &errors);
matrix_subtract(&errors, dep_var);
matrix_square(&errors);
gd_float tuple_loss = matrix_get_sum(&errors) / features->rows;
scores->mse += tuple_loss;
tuple_loss /= 2;
matrix_release(&errors);
return tuple_loss;
}
static gd_float linear_reg_predict(const Matrix *features, const Matrix *weights)
{
// p = x * w
return matrix_dot(features, weights);
}
GradientDescentAlgorithm gd_linear_regression = {
"linear-regression",
GD_DEPENDENT_VAR_CONTINUOUS,
METRIC_MSE, // same as loss function
0.0,
0.0, // not necessary
linear_reg_gradients,
linear_reg_test,
linear_reg_predict,
};

@@ -0,0 +1,108 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
 * logregr.cpp
*
* IDENTIFICATION
 * src/gausskernel/dbmind/db4ai/executor/gd/logregr.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
static void logreg_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
Matrix *weights, Matrix *gradients)
{
Assert(features->rows > 0);
// xT * ((1.0 / (1.0 + exp(-x*w))) - y)
Matrix sigma;
matrix_init(&sigma, features->rows);
matrix_mult_vector(features, weights, &sigma);
matrix_sigmoid(&sigma);
matrix_subtract(&sigma, dep_var);
Matrix x_t;
matrix_init_transpose(&x_t, features);
matrix_mult_vector(&x_t, &sigma, gradients);
matrix_release(&x_t);
matrix_release(&sigma);
}
static double logreg_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
const Matrix *weights, Scores *scores)
{
Assert(features->rows > 0);
Matrix predictions;
Matrix cost1;
Matrix cost2;
Matrix classification;
// p = 1.0 / (1.0 + exp(-x*w))
matrix_init(&predictions, features->rows);
matrix_mult_vector(features, weights, &predictions);
matrix_sigmoid(&predictions);
// compute relevance
matrix_init_clone(&classification, &predictions);
matrix_binary(&classification, 0.5, 0.0, 1.0);
matrix_relevance(&classification, dep_var, scores, 1.0);
matrix_release(&classification);
// compute loss using cross-entropy
// cost1 = y * log(p)
matrix_init_clone(&cost1, &predictions);
matrix_log(&cost1);
matrix_mult_entrywise(&cost1, dep_var);
// cost2 = (1-y) * log(1-p)
matrix_complement(&predictions);
matrix_log(&predictions);
matrix_init_clone(&cost2, dep_var);
matrix_complement(&cost2);
matrix_mult_entrywise(&cost2, &predictions);
// cost = sum(-cost1 - cost2) / N
matrix_negate(&cost1);
matrix_subtract(&cost1, &cost2);
gd_float tuple_loss = matrix_get_sum(&cost1) / features->rows;
matrix_release(&cost2);
matrix_release(&cost1);
matrix_release(&predictions);
return tuple_loss;
}
static gd_float logreg_predict(const Matrix *features, const Matrix *weights)
{
// p = 1.0 / (1.0 + exp(-x*w))
gd_float r = matrix_dot(features, weights);
r = 1.0 / (1.0 + exp(-r));
return r < 0.5 ? 0.0 : 1.0;
}
GradientDescentAlgorithm gd_logistic_regression = {
"logistic-regression",
GD_DEPENDENT_VAR_BINARY,
METRIC_ACCURACY | METRIC_F1 | METRIC_PRECISION | METRIC_RECALL | METRIC_LOSS,
0.0,
1.0,
logreg_gradients,
logreg_test,
logreg_predict,
};

@@ -0,0 +1,74 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* matrix.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/matrix.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/matrix.h"
#define MATRIX_LIMITED_OUTPUT 30
void matrix_print(const Matrix *matrix, StringInfo buf, bool full)
{
Assert(matrix != nullptr);
Assert(!matrix->transposed);
const gd_float *pf = matrix->data;
appendStringInfoChar(buf, '[');
for (int r = 0; r < matrix->rows; r++) {
if (!full && matrix->rows > MATRIX_LIMITED_OUTPUT && r > (MATRIX_LIMITED_OUTPUT / 2) &&
r < matrix->rows - (MATRIX_LIMITED_OUTPUT / 2)) {
if (matrix->columns > 1)
appendStringInfoString(buf, ",\n...");
else
appendStringInfoString(buf, ", ...");
r = matrix->rows - MATRIX_LIMITED_OUTPUT / 2;
continue;
}
if (matrix->columns > 1) {
if (r > 0)
appendStringInfoString(buf, ",\n");
appendStringInfoChar(buf, '[');
} else {
if (r > 0)
appendStringInfoString(buf, ", ");
}
for (int c = 0; c < matrix->columns; c++) {
if (c > 0)
appendStringInfoString(buf, ", ");
appendStringInfo(buf, "%.16g", *pf++);
}
if (matrix->columns > 1)
appendStringInfoChar(buf, ']');
}
appendStringInfoChar(buf, ']');
}
void elog_matrix(int elevel, const char *msg, const Matrix *matrix)
{
StringInfoData buf;
initStringInfo(&buf);
matrix_print(matrix, &buf, false);
ereport(elevel, (errmodule(MOD_DB4AI), errmsg("%s = %s", msg, buf.data)));
pfree(buf.data);
}

@@ -0,0 +1,78 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* optimizer_gd.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/optimizer_gd.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
// ////////////////////////////////////////////////////////////////////////
// gd: minibatch basic optimizer
typedef struct OptimizerMinibatch {
OptimizerGD opt;
const GradientDescentState *gd_state;
double learning_rate;
} OptimizerMinibatch;
static void opt_gd_end_iteration(OptimizerGD *optimizer)
{
OptimizerMinibatch *opt = (OptimizerMinibatch *)optimizer;
// decay the learning rate with decay^iterations
opt->learning_rate *= gd_get_node(opt->gd_state)->decay;
}
static void opt_gd_update_batch(OptimizerGD *optimizer, const Matrix *features, const Matrix *dep_var)
{
OptimizerMinibatch *opt = (OptimizerMinibatch *)optimizer;
// clear gradients of the batch
matrix_zeroes(&optimizer->gradients);
// update gradients
opt->gd_state->algorithm->gradients_callback(gd_get_node(opt->gd_state), features, dep_var, &optimizer->weights,
&optimizer->gradients);
elog_matrix(DEBUG1, "optimizer gd: gradients", &optimizer->gradients);
// add gradients to the model: weight -= alpha * gradients * scale / N
matrix_mult_scalar(&optimizer->gradients, opt->learning_rate / features->rows);
matrix_subtract(&optimizer->weights, &optimizer->gradients);
elog_matrix(DEBUG1, "optimizer gd: weights", &optimizer->weights);
}
static void opt_gd_release(OptimizerGD *optimizer)
{
pfree(optimizer);
}
OptimizerGD *gd_init_optimizer_gd(const GradientDescentState *gd_state)
{
OptimizerMinibatch *opt = (OptimizerMinibatch *)palloc0(sizeof(OptimizerMinibatch));
opt->opt.start_iteration = nullptr;
opt->opt.end_iteration = opt_gd_end_iteration;
opt->opt.update_batch = opt_gd_update_batch;
opt->opt.release = opt_gd_release;
opt->gd_state = gd_state;
opt->learning_rate = gd_get_node(gd_state)->learning_rate;
return &opt->opt;
}
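As a standalone illustration of the update rule implemented above (a toy sketch, not the openGauss matrix API; `minibatch_step` and `decayed_rate` are hypothetical names), each batch moves the weights by `learning_rate / rows` times the accumulated gradients, and the rate shrinks geometrically across iterations:

```cpp
#include <vector>
#include <cstddef>

// Toy minibatch step: w -= (alpha / n) * g, as in opt_gd_update_batch,
// where the gradients are scaled by learning_rate / features->rows.
void minibatch_step(std::vector<double> &weights, const std::vector<double> &gradients,
                    double learning_rate, std::size_t batch_rows)
{
    for (std::size_t i = 0; i < weights.size(); i++)
        weights[i] -= learning_rate / batch_rows * gradients[i];
}

// Rate after k iterations: opt_gd_end_iteration multiplies by decay once
// per iteration, giving learning_rate * decay^k overall.
double decayed_rate(double rate, double decay, int iterations)
{
    for (int i = 0; i < iterations; i++)
        rate *= decay;
    return rate;
}
```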


@@ -0,0 +1,130 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* optimizer_ngd.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/optimizer_ngd.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
// ////////////////////////////////////////////////////////////////////////
// ngd: normalized gradient descent optimizer
//
// An adaptation of NG algorithm from:
// Ross, Stéphane, Paul Mineiro, and John Langford.
// "Normalized online learning." arXiv preprint arXiv:1305.6646 (2013).
typedef struct OptimizerNormalize {
OptimizerGD opt;
const GradientDescentState *gd_state;
double learning_rate;
bool learn; // only first iteration
double scale_rate;
Matrix scale_gradients;
} OptimizerNormalize;
static void opt_ngd_end_iteration(OptimizerGD *optimizer)
{
OptimizerNormalize *opt = (OptimizerNormalize *)optimizer;
// decay the learning rate with decay^iterations
opt->learning_rate *= gd_get_node(opt->gd_state)->decay;
    // ensure that normalization is learned only during the first iteration
opt->learn = false;
}
static void opt_ngd_update_batch(OptimizerGD *optimizer, const Matrix *features, const Matrix *dep_var)
{
OptimizerNormalize *opt = (OptimizerNormalize *)optimizer;
if (opt->learn) {
Assert(features->columns == opt->scale_gradients.rows);
gd_float *pf = features->data;
for (int r = 0; r < features->rows; r++) {
gd_float *pw = optimizer->weights.data;
gd_float *ps = opt->scale_gradients.data;
for (int c = 0; c < features->columns; c++) {
gd_float qx = *pf++;
qx *= qx;
if (qx > *ps) {
// update weights and scaling of gradients
*pw *= *ps / qx;
*ps = qx;
}
if (*ps > 0) {
// update scale rate
opt->scale_rate += qx / *ps;
}
ps++;
pw++;
}
}
}
// clear gradients of the batch
matrix_zeroes(&optimizer->gradients);
opt->gd_state->algorithm->gradients_callback(gd_get_node(opt->gd_state), features, dep_var, &optimizer->weights,
&optimizer->gradients);
elog_matrix(DEBUG1, "optimizer ngd: gradients", &optimizer->gradients);
// normalize gradients
gd_float *pg = optimizer->gradients.data;
gd_float *ps = opt->scale_gradients.data;
for (int r = 0; r < opt->scale_gradients.rows; r++) {
gd_float s = 0.0;
if (*ps > 0)
s = (1.0 / opt->scale_rate) / *ps;
*pg *= s;
ps++;
pg++;
}
// add gradients to the model: weight -= alpha * scale_rate * gradients * scale_gradients
// do not divide by the number of rows like in a simple minibatch
matrix_mult_scalar(&optimizer->gradients, opt->learning_rate);
matrix_subtract(&optimizer->weights, &optimizer->gradients);
elog_matrix(DEBUG1, "optimizer ngd: weights", &optimizer->weights);
}
static void opt_ngd_release(OptimizerGD *optimizer)
{
pfree(optimizer);
}
OptimizerGD *gd_init_optimizer_ngd(const GradientDescentState *gd_state)
{
OptimizerNormalize *opt = (OptimizerNormalize *)palloc0(sizeof(OptimizerNormalize));
opt->opt.start_iteration = nullptr;
opt->opt.end_iteration = opt_ngd_end_iteration;
opt->opt.update_batch = opt_ngd_update_batch;
opt->opt.release = opt_ngd_release;
opt->gd_state = gd_state;
opt->learning_rate = gd_get_node(gd_state)->learning_rate;
opt->learn = true;
opt->scale_rate = 0.0;
matrix_init(&opt->scale_gradients, gd_state->n_features);
return &opt->opt;
}
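The first-iteration scale learning above can be sketched per feature (a toy illustration; `NgdScale` and `ngd_observe` are hypothetical names): the optimizer tracks the running maximum of x², rescales the weight whenever a larger magnitude arrives, and accumulates `scale_rate`:

```cpp
// Toy version of the NGD per-feature scale update (Ross et al., 2013),
// mirroring the first-iteration loop in opt_ngd_update_batch:
// track s = max over rows of x^2, shrink the weight when the scale grows,
// and accumulate scale_rate += x^2 / s. Note that because s starts at 0,
// the first observation rescales the weight to 0, as in the original loop.
struct NgdScale {
    double s = 0.0;          // running max of squared feature values
    double scale_rate = 0.0; // accumulated normalization rate
};

void ngd_observe(NgdScale &st, double &weight, double x)
{
    double qx = x * x;
    if (qx > st.s) {
        weight *= st.s / qx; // rescale weight to the new, larger scale
        st.s = qx;
    }
    if (st.s > 0)
        st.scale_rate += qx / st.s;
}
```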


@@ -0,0 +1,125 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* predict.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/predict.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "postgres.h"
#include "db4ai/gd.h"
#include "db4ai/model_warehouse.h"
#include "db4ai/predict_by.h"
typedef struct GradientDescentPredictor {
GradientDescentAlgorithm *algorithm;
int ncategories;
Oid return_type;
Matrix weights;
Matrix features;
Datum *categories;
} GradientDescentPredictor;
ModelPredictor gd_predict_prepare(const Model *model)
{
GradientDescentPredictor *gdp = (GradientDescentPredictor *)palloc0(sizeof(GradientDescentPredictor));
ModelGradientDescent *gdp_model = (ModelGradientDescent *)model;
gdp->algorithm = gd_get_algorithm(gdp_model->model.algorithm);
gdp->ncategories = gdp_model->ncategories;
gdp->return_type = gdp_model->model.return_type;
ArrayType *arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(gdp_model->weights));
Assert(arr->elemtype == FLOAT4OID);
int coefficients = ARR_DIMS(arr)[0];
matrix_init(&gdp->weights, coefficients);
matrix_init(&gdp->features, coefficients);
Datum dt;
bool isnull;
gd_float *pf = gdp->weights.data;
ArrayIterator it = array_create_iterator(arr, 0);
while (array_iterate(it, &dt, &isnull)) {
Assert(!isnull);
*pf++ = DatumGetFloat4(dt);
}
array_free_iterator(it);
if (arr != (ArrayType *)DatumGetPointer(gdp_model->weights))
pfree(arr);
if (gdp->ncategories > 0) {
arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(gdp_model->categories));
gdp->categories = (Datum *)palloc(ARR_DIMS(arr)[0] * sizeof(Datum));
int cat = 0;
it = array_create_iterator(arr, 0);
while (array_iterate(it, &dt, &isnull)) {
Assert(!isnull);
gdp->categories[cat++] = dt;
}
array_free_iterator(it);
if (arr != (ArrayType *)DatumGetPointer(gdp_model->categories))
pfree(arr);
} else
gdp->categories = nullptr;
    return (ModelPredictor)gdp;
}
Datum gd_predict(ModelPredictor pred, Datum *values, bool *isnull, Oid *types, int ncolumns)
{
// extract coefficients from model
GradientDescentPredictor *gdp = (GradientDescentPredictor *)pred;
// extract the features
if (ncolumns != (int)gdp->weights.rows - 1)
elog(ERROR, "Invalid number of features for prediction, provided %d, expected %d", ncolumns,
gdp->weights.rows - 1);
gd_float *w = gdp->features.data;
for (int i = 0; i < ncolumns; i++) {
double value;
if (isnull[i])
            value = 0.0; // default value for a NULL feature (it is certainly not the target column)
else
value = gd_datum_get_float(types[i], values[i]);
*w++ = value;
}
*w = 1.0; // bias
Datum result = 0;
gd_float r = gdp->algorithm->predict_callback(&gdp->features, &gdp->weights);
if (dep_var_is_binary(gdp->algorithm)) {
if (gdp->ncategories == 0) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
errmsg("For classification algorithms: %s, the number of categories should not be 0.",
gdp->algorithm->name)));
}
result = gdp->categories[0];
if (r != gdp->algorithm->min_class) {
if (gdp->ncategories == 2 && r == gdp->algorithm->max_class)
result = gdp->categories[1];
}
} else {
result = gd_float_get_datum(gdp->return_type, r);
}
return result;
}
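`gd_predict` above copies the input columns into a feature vector, appends a constant 1.0 bias term, and dots it with the weights, which is why the model stores `ncolumns + 1` coefficients. A minimal standalone sketch (hypothetical `predict_raw`, not the real API):

```cpp
#include <vector>
#include <cstddef>

// Toy version of gd_predict's feature handling: dot the input columns with
// the first weights, then add the trailing bias coefficient (weight for a
// constant 1.0 feature). Expects weights.size() == input.size() + 1, the
// same invariant gd_predict checks before predicting.
double predict_raw(const std::vector<double> &input, const std::vector<double> &weights)
{
    double r = 0.0;
    for (std::size_t i = 0; i < input.size(); i++)
        r += input[i] * weights[i];
    r += 1.0 * weights[input.size()]; // bias term
    return r;
}
```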


@@ -0,0 +1,280 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* shuffle_cache.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/shuffle_cache.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
/*
* Shuffle using a limited cache is performed by caching and training batches
* in random order. For each new batch there are two random options:
* - append (there are free slots into the cache)
* - or train (and test) an existing batch and replace it by the new one
* At the end of the iteration, the cache is emptied randomly. At each step,
* one remaining batch is selected, trained and removed.
* At the end of the two phases, all batches have been visited only once in a
* random sequence, but the ability to shuffle batches depends on the cache size
* (the available working memory) and the batch size (a matrix of N rows
* by M features). The probability distribution is not uniform and initial
* batches have a higher probability than the last batches, but the shuffling
 * is good enough and has a very small impact on the performance of GD.
*/
typedef struct ShuffleCache {
ShuffleGD shf;
int cache_size;
int *cache_batch;
Matrix **cache_features;
Matrix **cache_dep_var;
int cache_allocated;
int max_cache_usage;
int num_batches;
int batch_size;
int n_features;
int iteration;
int cached;
int next;
struct drand48_data rnd;
} ShuffleCache;
inline int32_t rnd_next(ShuffleCache *shf)
{
    long int r;
    lrand48_r(&shf->rnd, &r);
    return (int32_t)r;
}
static void update_batch(ShuffleCache *cache)
{
Assert(cache->next < cache->cached);
ereport(DEBUG1, (errmodule(MOD_DB4AI), errmsg("GD shuffle cache iteration %d train batch %d of %d",
cache->iteration, cache->cache_batch[cache->next] + 1, cache->num_batches)));
Matrix *features = cache->cache_features[cache->next];
Matrix *dep_var = cache->cache_dep_var[cache->next];
cache->shf.optimizer->update_batch(cache->shf.optimizer, features, dep_var);
if (features->rows < cache->batch_size) {
matrix_resize(features, cache->batch_size, cache->n_features);
matrix_resize(dep_var, cache->batch_size, 1);
}
}
static void swap_last(ShuffleCache *cache)
{
Matrix *features = cache->cache_features[cache->next];
cache->cache_features[cache->next] = cache->cache_features[cache->cached];
cache->cache_features[cache->cached] = features;
Matrix *dep_var = cache->cache_dep_var[cache->next];
cache->cache_dep_var[cache->next] = cache->cache_dep_var[cache->cached];
cache->cache_dep_var[cache->cached] = dep_var;
cache->cache_batch[cache->next] = cache->cache_batch[cache->cached];
}
static void cache_start_iteration(ShuffleGD *shuffle)
{
ShuffleCache *cache = (ShuffleCache *)shuffle;
cache->next = -1;
cache->cached = 0;
cache->num_batches = 0;
}
static void cache_end_iteration(ShuffleGD *shuffle)
{
ShuffleCache *cache = (ShuffleCache *)shuffle;
if (cache->iteration == 0) {
cache->max_cache_usage = cache->cache_size;
if (cache->max_cache_usage > cache->num_batches)
cache->max_cache_usage = cache->num_batches;
} else {
// empty the cache
while (cache->cached > 0) {
cache->next = rnd_next(cache) % cache->cached;
update_batch(cache);
cache->cached--;
if (cache->next != cache->cached)
swap_last(cache);
}
}
cache->iteration++;
}
static void cache_release(ShuffleGD *shuffle)
{
ShuffleCache *cache = (ShuffleCache *)shuffle;
for (int c = 0; c < cache->cache_allocated; c++) {
matrix_release(cache->cache_features[c]);
pfree(cache->cache_features[c]);
matrix_release(cache->cache_dep_var[c]);
pfree(cache->cache_dep_var[c]);
}
pfree(cache->cache_features);
pfree(cache->cache_dep_var);
pfree(cache->cache_batch);
pfree(cache);
}
static Matrix *cache_get(ShuffleGD *shuffle, Matrix **pdep_var)
{
ShuffleCache *cache = (ShuffleCache *)shuffle;
Assert(cache->next == -1);
Matrix *features;
Matrix *dep_var;
if (cache->iteration == 0) {
// special case, do not shuffle
Assert(cache->cached == 0);
cache->next = 0;
cache->cached++;
if (cache->cache_allocated == 0) {
features = (Matrix *)palloc0(sizeof(Matrix));
matrix_init(features, cache->batch_size, cache->n_features);
dep_var = (Matrix *)palloc0(sizeof(Matrix));
matrix_init(dep_var, cache->batch_size);
cache->cache_features[0] = features;
cache->cache_dep_var[0] = dep_var;
cache->cache_allocated++;
} else {
            // reuse the batch, it has already been allocated
features = cache->cache_features[0];
dep_var = cache->cache_dep_var[0];
}
} else {
// look for an empty slot, otherwise reuse one
cache->next = rnd_next(cache) % cache->max_cache_usage;
if (cache->next < cache->cached) {
// reuse slot
update_batch(cache);
features = cache->cache_features[cache->next];
dep_var = cache->cache_dep_var[cache->next];
} else {
// append
cache->next = cache->cached++;
if (cache->next == cache->cache_allocated) {
features = (Matrix *)palloc0(sizeof(Matrix));
matrix_init(features, cache->batch_size, cache->n_features);
dep_var = (Matrix *)palloc0(sizeof(Matrix));
matrix_init(dep_var, cache->batch_size);
cache->cache_features[cache->next] = features;
cache->cache_dep_var[cache->next] = dep_var;
cache->cache_allocated++;
} else {
features = cache->cache_features[cache->next];
dep_var = cache->cache_dep_var[cache->next];
}
}
}
cache->cache_batch[cache->next] = cache->num_batches;
cache->num_batches++;
*pdep_var = dep_var;
return features;
}
static void cache_unget(ShuffleGD *shuffle, int tuples)
{
ShuffleCache *cache = (ShuffleCache *)shuffle;
Assert(cache->next != -1);
Assert(cache->cached > 0);
if (tuples == 0) {
// ignore batch
if (cache->iteration == 0) {
cache->cached--;
} else {
// special case when the last batch is empty
// there are two potential cases
cache->cached--;
if (cache->next < cache->cached) {
// it is in the middle, swap it with the last
swap_last(cache);
}
}
cache->num_batches--;
} else {
if (tuples < cache->batch_size) {
// resize batch temporarily
Matrix *features = cache->cache_features[cache->next];
Matrix *dep_var = cache->cache_dep_var[cache->next];
matrix_resize(features, tuples, cache->n_features);
matrix_resize(dep_var, tuples, 1);
}
if (cache->iteration == 0) {
Assert(cache->next == 0);
update_batch(cache);
cache->cached--;
}
}
cache->next = -1;
}
ShuffleGD *gd_init_shuffle_cache(const GradientDescentState *gd_state)
{
int batch_size = gd_get_node(gd_state)->batch_size;
// check if a batch fits into memory
int64_t avail_mem = u_sess->attr.attr_memory.work_mem * 1024LL -
matrix_expected_size(gd_state->n_features) * 2; // weights & gradients
int batch_mem = matrix_expected_size(batch_size, gd_state->n_features) // features
        + matrix_expected_size(batch_size); // dep var (one value per row of the batch)
if (batch_mem > avail_mem)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("batch size is too large for the available working memory")));
// initialize
ShuffleCache *cache = (ShuffleCache *)palloc0(sizeof(ShuffleCache));
srand48_r(gd_get_node(gd_state)->seed, &cache->rnd);
cache->shf.start_iteration = cache_start_iteration;
cache->shf.end_iteration = cache_end_iteration;
cache->shf.release = cache_release;
cache->shf.get = cache_get;
cache->shf.unget = cache_unget;
// cache for shuffle
cache->cache_size = avail_mem / (batch_mem + 2 * sizeof(Matrix *) + sizeof(int));
if (cache->cache_size == 0)
cache->cache_size = 1; // shuffle is not possible
ereport(NOTICE, (errmodule(MOD_DB4AI), errmsg("GD shuffle cache size %d", cache->cache_size)));
cache->batch_size = batch_size;
cache->n_features = gd_state->n_features;
cache->cache_batch = (int *)palloc(cache->cache_size * sizeof(int));
cache->cache_features = (Matrix **)palloc(cache->cache_size * sizeof(Matrix *));
cache->cache_dep_var = (Matrix **)palloc(cache->cache_size * sizeof(Matrix *));
cache->cache_allocated = 0;
cache->max_cache_usage = 0;
cache->iteration = 0;
return &cache->shf;
}
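The two-phase scheme described in the header comment can be simulated with a toy cache (hypothetical `shuffle_train_order`, using `std::mt19937` in place of `lrand48_r`): an incoming batch either fills a free slot or evicts (trains) a random cached one, and the drain phase trains the remainder in random order, so every batch is trained exactly once:

```cpp
#include <vector>
#include <random>
#include <algorithm>

// Toy model of the shuffle in shuffle_cache.cpp. Returns the order in which
// batch ids 0..num_batches-1 would be trained: phase 1 randomly appends or
// evict-and-replaces cached batches; phase 2 drains the cache randomly.
std::vector<int> shuffle_train_order(int num_batches, int cache_size, unsigned seed)
{
    std::mt19937 rnd(seed);
    std::vector<int> cache, trained;
    for (int b = 0; b < num_batches; b++) {
        int slot = (int)(rnd() % cache_size);
        if (slot < (int)cache.size()) {
            trained.push_back(cache[slot]); // evict: train the cached batch
            cache[slot] = b;                // and replace it with the new one
        } else {
            cache.push_back(b);             // free slot: just append
        }
    }
    while (!cache.empty()) {                // drain phase, random order
        int slot = (int)(rnd() % cache.size());
        trained.push_back(cache[slot]);
        std::swap(cache[slot], cache.back());
        cache.pop_back();
    }
    return trained;
}
```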


@@ -0,0 +1,101 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* svm.cpp
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/gd/svm.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/gd.h"
static void svmc_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
Matrix *weights, Matrix *gradients)
{
Assert(features->rows > 0);
// distances = 1 - y * (x * w)
Matrix distance;
matrix_init(&distance, features->rows);
matrix_mult_vector(features, weights, &distance);
matrix_mult_entrywise(&distance, dep_var);
matrix_complement(&distance);
Assert(distance.rows == dep_var->rows);
Assert(distance.columns == 1);
const gd_float *pf = features->data;
const gd_float *py = dep_var->data;
const gd_float *pd = distance.data;
gd_float *pg = gradients->data;
for (int r = 0; r < distance.rows; r++) {
gd_float y = *py++;
gd_float d = *pd++;
if (d > 0) {
for (int f = 0; f < features->columns; f++)
pg[f] -= y * pf[f];
}
pf += features->columns;
}
matrix_release(&distance);
matrix_mult_scalar(gradients, gd_node->lambda * 2.0);
}
static double svmc_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
const Matrix *weights, Scores *scores)
{
Assert(features->rows > 0);
Matrix distances;
Matrix predictions;
matrix_init(&distances, features->rows);
matrix_mult_vector(features, weights, &distances);
matrix_init_clone(&predictions, &distances);
matrix_positive(&predictions);
matrix_binary(&predictions, FLT_MIN, -1.0, 1.0);
matrix_relevance(&predictions, dep_var, scores, 1.0);
matrix_release(&predictions);
// cost = (1 - y * x * w)
matrix_mult_entrywise(&distances, dep_var);
matrix_complement(&distances);
matrix_positive(&distances);
gd_float tuple_loss = matrix_get_sum(&distances) / features->rows;
matrix_release(&distances);
return tuple_loss;
}
static gd_float svmc_predict(const Matrix *features, const Matrix *weights)
{
double r = matrix_dot(features, weights);
return r < 0 ? -1.0 : 1.0;
}
GradientDescentAlgorithm gd_svm_classification = {
"svm-classification",
GD_DEPENDENT_VAR_BINARY,
METRIC_ACCURACY | METRIC_F1 | METRIC_PRECISION | METRIC_RECALL | METRIC_LOSS,
-1.0,
1.0,
svmc_gradients,
svmc_test,
svmc_predict,
};
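The gradient computation above follows the subgradient of the hinge loss: a row contributes only when its margin term 1 - y(x·w) is positive. A per-sample sketch (toy code; `hinge_gradient` is a hypothetical name, and the lambda scaling applied afterwards by `svmc_gradients` is omitted):

```cpp
#include <vector>
#include <cstddef>

// Per-sample hinge-loss subgradient, mirroring the row loop in
// svmc_gradients: when d = 1 - y * (x . w) > 0 the sample is inside the
// margin and contributes -y * x; otherwise it contributes nothing.
void hinge_gradient(const std::vector<double> &x, double y,
                    const std::vector<double> &w, std::vector<double> &grad)
{
    double dot = 0.0;
    for (std::size_t i = 0; i < x.size(); i++)
        dot += x[i] * w[i];
    double d = 1.0 - y * dot;
    if (d > 0) {
        for (std::size_t i = 0; i < x.size(); i++)
            grad[i] -= y * x[i];
    }
}
```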


@@ -0,0 +1,401 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ---------------------------------------------------------------------------------------
*
*
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/executor/hyperparameter_validation.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "db4ai/hyperparameter_validation.h"
#include "nodes/plannodes.h"
#define ARRAY_LENGTH(x) (sizeof(x) / sizeof((x)[0]))
static void init_hyperparameter_validation(HyperparameterValidation *v, void *min_value, bool min_inclusive,
void *max_value, bool max_inclusive, const char **valid_values, int32_t valid_values_size)
{
v->min_value = min_value;
v->min_inclusive = min_inclusive;
v->max_value = max_value;
v->max_inclusive = max_inclusive;
v->valid_values = valid_values;
v->valid_values_size = valid_values_size;
}
// Definitions of hyperparameters
#define HYPERPARAMETER_BOOL(name, default_value, struct_name, attribute) \
{ \
name, BoolGetDatum(default_value), PointerGetDatum(NULL), PointerGetDatum(NULL), NULL, NULL, BOOLOID, 0, \
offsetof(struct_name, attribute), false, false \
}
#define HYPERPARAMETER_ENUM(name, default_value, enum_values, enum_values_size, enum_setter, struct_name, attribute) \
{ \
name, CStringGetDatum(default_value), PointerGetDatum(NULL), PointerGetDatum(NULL), enum_values, enum_setter, \
ANYENUMOID, enum_values_size, offsetof(struct_name, attribute), false, false \
}
#define HYPERPARAMETER_INT4(name, default_value, min, min_inclusive, max, max_inclusive, struct_name, attribute) \
{ \
name, Int32GetDatum(default_value), Int32GetDatum(min), Int32GetDatum(max), NULL, NULL, INT4OID, 0, \
offsetof(struct_name, attribute), min_inclusive, max_inclusive \
}
#define HYPERPARAMETER_FLOAT8(name, default_value, min, min_inclusive, max, max_inclusive, struct_name, attribute) \
{ \
name, Float8GetDatum(default_value), Float8GetDatum(min), Float8GetDatum(max), NULL, NULL, FLOAT8OID, 0, \
offsetof(struct_name, attribute), min_inclusive, max_inclusive \
}
const char* gd_optimizer_ml[] = {"gd", "ngd"};
// Used by linear regression and logistic regression
HyperparameterDefinition logistic_regression_hyperparameter_definitions[] = {
HYPERPARAMETER_INT4("batch_size", 1000, 1, true, INT32_MAX, true,
GradientDescent, batch_size),
HYPERPARAMETER_FLOAT8("decay", 0.95, 0.0, false, DBL_MAX, true,
GradientDescent, decay),
HYPERPARAMETER_FLOAT8("learning_rate", 0.8, 0.0, false, DBL_MAX, true,
GradientDescent, learning_rate),
HYPERPARAMETER_INT4("max_iterations", 100, 1, true, INT32_MAX, true,
GradientDescent, max_iterations),
HYPERPARAMETER_INT4("max_seconds", 0, 0, true, INT32_MAX, true,
GradientDescent, max_seconds),
HYPERPARAMETER_ENUM("optimizer", "gd", gd_optimizer_ml, ARRAY_LENGTH(gd_optimizer_ml), optimizer_ml_setter,
GradientDescent, optimizer),
HYPERPARAMETER_FLOAT8("tolerance", 0.0005, 0.0, false, DBL_MAX, true,
GradientDescent, tolerance),
HYPERPARAMETER_INT4("seed", 0, 0, true, INT32_MAX, true,
GradientDescent, seed),
HYPERPARAMETER_BOOL("verbose", false,
GradientDescent, verbose),
};
HyperparameterDefinition svm_hyperparameter_definitions[] = {
HYPERPARAMETER_INT4("batch_size", 1000, 1, true, INT32_MAX, true,
GradientDescent, batch_size),
HYPERPARAMETER_FLOAT8("decay", 0.95, 0.0, false, DBL_MAX, true,
GradientDescent, decay),
HYPERPARAMETER_FLOAT8("lambda", 0.01, 0.0, false, DBL_MAX, true,
GradientDescent, lambda),
HYPERPARAMETER_FLOAT8("learning_rate", 0.8, 0.0, false, DBL_MAX, true,
GradientDescent, learning_rate),
HYPERPARAMETER_INT4("max_iterations", 100, 1, true, INT32_MAX, true,
GradientDescent, max_iterations),
HYPERPARAMETER_INT4("max_seconds", 0, 0, true, INT32_MAX, true,
GradientDescent, max_seconds),
HYPERPARAMETER_ENUM("optimizer", "gd", gd_optimizer_ml, ARRAY_LENGTH(gd_optimizer_ml), optimizer_ml_setter,
GradientDescent, optimizer),
HYPERPARAMETER_FLOAT8("tolerance", 0.0005, 0.0, false, DBL_MAX, true,
GradientDescent, tolerance),
HYPERPARAMETER_INT4("seed", 0, 0, true, INT32_MAX, true,
GradientDescent, seed),
HYPERPARAMETER_BOOL("verbose", false,
GradientDescent, verbose),
};
HyperparameterDefinition kmeans_hyperparameter_definitions[] = {
/* nothing to do now, will do when needing */
};
void get_hyperparameter_definitions(AlgorithmML algorithm, HyperparameterDefinition **result, int32_t *result_size)
{
switch (algorithm) {
case LOGISTIC_REGRESSION:
case LINEAR_REGRESSION:
*result = logistic_regression_hyperparameter_definitions;
*result_size = ARRAY_LENGTH(logistic_regression_hyperparameter_definitions);
break;
case SVM_CLASSIFICATION:
*result = svm_hyperparameter_definitions;
*result_size = ARRAY_LENGTH(svm_hyperparameter_definitions);
break;
case KMEANS:
*result = kmeans_hyperparameter_definitions;
*result_size = ARRAY_LENGTH(kmeans_hyperparameter_definitions);
break;
        case INVALID_ALGORITHM_ML:
        default: {
            const char *s = "logistic_regression, svm_classification, linear_regression, kmeans";
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Algorithm is not supported. Supported algorithms: %s", s)));
        }
}
}
static void add_model_hyperparameter(Model *model, const char *name, Oid type, Datum value)
{
Hyperparameter *hyperp = (Hyperparameter *)palloc0(sizeof(Hyperparameter));
hyperp->name = pstrdup(name);
hyperp->type = type;
hyperp->value = value;
model->hyperparameters = lappend(model->hyperparameters, hyperp);
}
void update_model_hyperparameter(Model *model, const char *name, Oid type, Datum value)
{
ListCell *lc;
foreach(lc, model->hyperparameters) {
Hyperparameter *hyperp = lfirst_node(Hyperparameter, lc);
if (strcmp(hyperp->name, name) == 0) {
hyperp->type = type;
hyperp->value = value;
return;
}
}
Assert(false);
}
// Set int hyperparameter
void set_hyperparameter_value(const char *name, int *hyperparameter, Value *value, VariableSetKind kind,
int default_value, Model *model, HyperparameterValidation *validation)
{
if (kind == VAR_SET_DEFAULT) {
*hyperparameter = default_value;
ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%d)", name, *hyperparameter)));
} else if (kind == VAR_SET_VALUE) {
if (value == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s cannot take NULL value", name)));
} else if (value->type != T_Integer) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be an integer", name)));
}
*hyperparameter = intVal(value);
ereport(NOTICE, (errmsg("Hyperparameter %s takes value %d", name, *hyperparameter)));
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid hyperparameter value for %s ", name)));
}
add_model_hyperparameter(model, name, INT4OID, Int32GetDatum(*hyperparameter));
if (validation->min_value != NULL && validation->max_value != NULL) {
bool out_of_range =
(*hyperparameter < *(int *)validation->min_value || *hyperparameter > *(int *)validation->max_value);
if (!validation->min_inclusive && *hyperparameter <= *(int *)validation->min_value)
out_of_range = true;
if (!validation->max_inclusive && *hyperparameter >= *(int *)validation->max_value)
out_of_range = true;
if (out_of_range) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be in the range %c%d,%d%c", name, validation->min_inclusive ? '[' : '(',
*(int *)validation->min_value, *(int *)validation->max_value, validation->max_inclusive ? ']' : ')')));
}
}
}
// Set double hyperparameter
void set_hyperparameter_value(const char *name, double *hyperparameter, Value *value, VariableSetKind kind,
double default_value, Model *model, HyperparameterValidation *validation)
{
if (kind == VAR_SET_DEFAULT) {
*hyperparameter = default_value;
ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%f)", name, *hyperparameter)));
} else if (kind == VAR_SET_VALUE) {
if (value == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s cannot take NULL value", name)));
} else if (value->type == T_Float) {
*hyperparameter = floatVal(value);
} else if (value->type == T_Integer) {
*hyperparameter = (double)intVal(value);
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be a floating point number", name)));
}
ereport(NOTICE, (errmsg("Hyperparameter %s takes value %f", name, *hyperparameter)));
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid hyperparameter value for %s ", name)));
}
add_model_hyperparameter(model, name, FLOAT8OID, Float8GetDatum(*hyperparameter));
if (validation->min_value != NULL && validation->max_value != NULL) {
bool out_of_range =
(*hyperparameter < *(double *)validation->min_value || *hyperparameter > *(double *)validation->max_value);
if (!validation->min_inclusive && *hyperparameter <= *(double *)validation->min_value)
out_of_range = true;
if (!validation->max_inclusive && *hyperparameter >= *(double *)validation->max_value)
out_of_range = true;
if (out_of_range) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be in the range %c%.8g,%.8g%c", name,
validation->min_inclusive ? '[' : '(', *(double *)validation->min_value,
*(double *)validation->max_value, validation->max_inclusive ? ']' : ')')));
}
}
}
// Set string hyperparameter (no const)
void set_hyperparameter_value(const char *name, char **hyperparameter, Value *value, VariableSetKind kind,
char *default_value, Model *model, HyperparameterValidation *validation)
{
if (kind == VAR_SET_DEFAULT) {
*hyperparameter = (char*)palloc((strlen(default_value) + 1) * sizeof(char));
errno_t err = strcpy_s(*hyperparameter, strlen(default_value) + 1, default_value);
securec_check(err, "\0", "\0");
ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%s)", name, *hyperparameter)));
} else if (kind == VAR_SET_VALUE) {
if (value == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s cannot take NULL value", name)));
} else if (value->type != T_String) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be a string", name)));
}
*hyperparameter = strVal(value);
ereport(NOTICE, (errmsg("Hyperparameter %s takes value %s", name, *hyperparameter)));
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid hyperparameter value for %s ", name)));
}
if (validation->valid_values != NULL) {
bool found = false;
for (int i = 0; i < validation->valid_values_size; i++) {
if (0 == strcmp(validation->valid_values[i], *hyperparameter)) {
found = true;
break;
}
}
if (!found) {
StringInfo str = makeStringInfo();
for (int i = 0; i < validation->valid_values_size; i++) {
if (i != 0)
appendStringInfoString(str, ", ");
appendStringInfoString(str, (validation->valid_values)[i]);
}
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid hyperparameter value for %s. Valid values are: %s. (default is %s)", name, str->data,
default_value)));
}
}
add_model_hyperparameter(model, name, VARCHAROID, PointerGetDatum(cstring_to_text(*hyperparameter)));
}
inline const char *bool_to_str(bool value)
{
return value ? "TRUE" : "FALSE";
}
// Set boolean hyperparameter
void set_hyperparameter_value(const char *name, bool *hyperparameter, Value *value, VariableSetKind kind,
bool default_value, Model *model, HyperparameterValidation *validation)
{
if (kind == VAR_SET_DEFAULT) {
*hyperparameter = default_value;
ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%s)", name, bool_to_str(*hyperparameter))));
} else if (kind == VAR_SET_VALUE) {
if (value == NULL) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s cannot take NULL value", name)));
} else if (value->type == T_String) {
char *str = strVal(value);
if (strcmp(str, "true") == 0) {
*hyperparameter = true;
} else if (strcmp(str, "false") == 0) {
*hyperparameter = false;
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s is not a valid string for boolean (i.e. 'true' or 'false')", name)));
}
} else if (value->type == T_Integer) {
*hyperparameter = (intVal(value) != 0);
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s must be a boolean or integer", name)));
}
ereport(NOTICE, (errmsg("Hyperparameter %s takes value %s", name, bool_to_str(*hyperparameter))));
} else {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Invalid hyperparameter value for %s ", name)));
}
add_model_hyperparameter(model, name, BOOLOID, BoolGetDatum(*hyperparameter));
}
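The boolean encodings accepted above (the strings "true"/"false", or an integer where any nonzero value means true) can be exercised standalone. The following is a minimal sketch that merges both encodings into a single string parser; `parse_bool_hyperparameter` is a hypothetical helper for illustration, not part of the DB4AI API:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Mirrors the acceptance rules above: the strings "true"/"false",
// or an integer where any nonzero value means true.
// Returns std::nullopt for inputs the setter would reject with an error.
std::optional<bool> parse_bool_hyperparameter(const std::string &value)
{
    if (value == "true")
        return true;
    if (value == "false")
        return false;
    // try an integer encoding, mirroring the T_Integer branch
    try {
        size_t pos = 0;
        long n = std::stol(value, &pos);
        if (pos == value.size())
            return n != 0;
    } catch (...) {
    }
    return std::nullopt; // would raise ERRCODE_INVALID_PARAMETER_VALUE
}
```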
// Set the hyperparameters according to the definitions. In the process, the range and values of
// each parameter are validated
void configure_hyperparameters(AlgorithmML algorithm, List *hyperparameters, Model *model, void *hyperparameter_struct)
{
HyperparameterDefinition *definitions;
int32_t definitions_size;
get_hyperparameter_definitions(algorithm, &definitions, &definitions_size);
HyperparameterValidation validation;
for (int32_t i = 0; i < definitions_size; i++) {
switch (definitions[i].type) {
case INT4OID: {
int *value_addr = (int *)((char *)hyperparameter_struct + definitions[i].offset);
int value_min = DatumGetInt32(definitions[i].min_value);
int value_max = DatumGetInt32(definitions[i].max_value);
init_hyperparameter_validation(&validation, &value_min, definitions[i].min_inclusive, &value_max,
definitions[i].max_inclusive, NULL, 0);
set_hyperparameter<int>(definitions[i].name, value_addr, hyperparameters,
DatumGetInt32(definitions[i].default_value), model, &validation);
break;
}
case FLOAT8OID: {
double *value_addr = (double *)((char *)hyperparameter_struct + definitions[i].offset);
double value_min = DatumGetFloat8(definitions[i].min_value);
double value_max = DatumGetFloat8(definitions[i].max_value);
init_hyperparameter_validation(&validation, &value_min, definitions[i].min_inclusive, &value_max,
definitions[i].max_inclusive, NULL, 0);
set_hyperparameter<double>(definitions[i].name, value_addr, hyperparameters,
DatumGetFloat8(definitions[i].default_value), model, &validation);
break;
}
case BOOLOID: {
bool *value_addr = (bool *)((char *)hyperparameter_struct + definitions[i].offset);
set_hyperparameter<bool>(definitions[i].name, value_addr, hyperparameters,
DatumGetBool(definitions[i].default_value), model, NULL);
break;
}
case ANYENUMOID: {
void *value_addr = (void *)((char *)hyperparameter_struct + definitions[i].offset);
char *str = NULL;
init_hyperparameter_validation(&validation, NULL, NULL, NULL, NULL, definitions[i].valid_values,
definitions[i].valid_values_size);
set_hyperparameter<char *>(definitions[i].name, &str, hyperparameters,
DatumGetCString(definitions[i].default_value), model, &validation);
definitions[i].enum_setter(str, value_addr);
break;
}
default: {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Invalid hyperparameter OID %d for hyperparameter %s", definitions[i].type,
definitions[i].name)));
}
}
}
}
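For the numeric cases above, the bounds passed to init_hyperparameter_validation() together with the min_inclusive/max_inclusive flags imply a range check along these lines. `in_range` is a hypothetical helper sketching the logic assumed to live inside set_hyperparameter&lt;T&gt;, not the actual implementation:

```cpp
#include <cassert>

// Sketch of the range check implied by min_value/max_value and the
// min_inclusive/max_inclusive flags in HyperparameterValidation.
template <typename T>
bool in_range(T value, T min_value, bool min_inclusive, T max_value, bool max_inclusive)
{
    bool above_min = min_inclusive ? (value >= min_value) : (value > min_value);
    bool below_max = max_inclusive ? (value <= max_value) : (value < max_value);
    return above_min && below_max;
}
```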


@ -0,0 +1,22 @@
# kmeans.cmake
set(TGT_kmeans_SRC
${CMAKE_CURRENT_SOURCE_DIR}/kmeans.cpp
)
set(TGT_kmeans_INC
${PROJECT_OPENGS_DIR}/contrib/log_fdw
${PROJECT_TRUNK_DIR}/distribute/bin/kmeans
${PROJECT_SRC_DIR}/include/libcomm
${PROJECT_SRC_DIR}/include
${PROJECT_SRC_DIR}/lib/gstrace
${LZ4_INCLUDE_PATH}
${LIBCGROUP_INCLUDE_PATH}
${LIBORC_INCLUDE_PATH}
${EVENT_INCLUDE_PATH}
${PROTOBUF_INCLUDE_PATH}
${ZLIB_INCLUDE_PATH}
)
set(kmeans_DEF_OPTIONS ${MACRO_OPTIONS})
set(kmeans_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS} -std=c++14 -fPIE)
list(REMOVE_ITEM kmeans_COMPILE_OPTIONS -fPIC)
set(kmeans_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor_kmeans TGT_kmeans_SRC TGT_kmeans_INC "${kmeans_DEF_OPTIONS}" "${kmeans_COMPILE_OPTIONS}" "${kmeans_LINK_OPTIONS}")


@ -0,0 +1,23 @@
#---------------------------------------------------------------------------------------
#
# IDENTIFICATION
# src/gausskernel/dbmind/db4ai/executor/kmeans
#
# ---------------------------------------------------------------------------------------
subdir = src/gausskernel/dbmind/db4ai/executor/kmeans
top_builddir = ../../../../../..
include $(top_builddir)/src/Makefile.global
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif
OBJS = kmeans.o
include $(top_srcdir)/src/gausskernel/common.mk


@ -0,0 +1,699 @@
/**
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
kmeans.cpp
k-means operations
IDENTIFICATION
src/gausskernel/dbmind/db4ai/executor/kmeans/kmeans.cpp
---------------------------------------------------------------------------------------
* */
#include <random>
#include "nodes/execnodes.h"
#include "nodes/pg_list.h"
#include "postgres_ext.h"
#include "db4ai/fp_ops.h"
#include "db4ai/distance_functions.h"
#include "db4ai/model_warehouse.h"
#include "db4ai/predict_by.h"
#include "db4ai/db4ai_cpu.h"
/*
* to verify that the pg array we get complies with what we expect
 * we discard entries that do not pass the test because we cannot do
 * much (consistently) with them anyway
*/
bool verify_pgarray(ArrayType const * pg_array, int32_t n)
{
/*
* We expect the input to be an n-element array of doubles; verify that. We
* don't need to use deconstruct_array() since the array data is just
* going to look like a C array of n double values.
*/
if (unlikely((ARR_NDIM(pg_array) != 1) || (ARR_DIMS(pg_array)[0] != n) || ARR_HASNULL(pg_array) ||
(ARR_ELEMTYPE(pg_array) != FLOAT8OID)))
return false;
return true;
}
/*
* this copies the coordinates found inside a slot onto an array we own
*/
bool copy_slot_coordinates_to_array(GSPoint *coordinates, TupleTableSlot const * slot, uint32_t const dimension)
{
if (unlikely((slot == nullptr) or (coordinates == nullptr)))
return false;
/*
 * we obtain the coordinates of the current point (the function call incurs
 * detoasting, so memory may be allocated and has to be freed once we no longer need it)
*/
ArrayType *current_point_pgarray = DatumGetArrayTypePCopy(slot->tts_values[0]);
bool const valid_point = verify_pgarray(current_point_pgarray, dimension);
bool release_point = PointerGetDatum(current_point_pgarray) != slot->tts_values[0];
/*
* if the point is not valid and it was originally toasted, then we release the copy
*/
if (unlikely(!valid_point && release_point)) {
pfree(current_point_pgarray);
release_point = false;
}
coordinates->pg_coordinates = current_point_pgarray;
coordinates->should_free = release_point;
return valid_point;
}
/*
* this sets the weights of a set of candidates to 1 (every point is the centroid of itself)
*/
void reset_weights(List const * centroids)
{
ListCell const * current_centroid_cell = centroids ? centroids->head : nullptr;
GSPoint *centroid = nullptr;
for (; current_centroid_cell != nullptr; current_centroid_cell = lnext(current_centroid_cell)) {
centroid = reinterpret_cast<GSPoint *>(lfirst(current_centroid_cell));
centroid->weight = 1U;
}
}
/*
 * given a set of centroids (as a PG list) and a point, this function computes the distance to the closest
* centroid
*/
bool closest_centroid(List const * centroids, GSPoint const * point, uint32_t const dimension, double *distance)
{
ListCell const * current_centroid_cell = centroids ? centroids->head : nullptr;
GSPoint *centroid = nullptr;
GSPoint *closest_centroid_ptr = nullptr;
bool result = false;
bool min_distance_changed = false;
double local_distance = 0.;
auto min_distance = DBL_MAX;
auto const * point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(point->pg_coordinates));
for (; current_centroid_cell != nullptr; current_centroid_cell = lnext(current_centroid_cell)) {
/*
* low temporal locality for a prefetch for read
*/
prefetch(lnext(current_centroid_cell), 0, 1);
centroid = reinterpret_cast<GSPoint *>(lfirst(current_centroid_cell));
local_distance = l2_squared(point_coordinates,
reinterpret_cast<double const *>(ARR_DATA_PTR(centroid->pg_coordinates)), dimension);
min_distance_changed = local_distance < min_distance;
min_distance = min_distance_changed ? local_distance : min_distance;
closest_centroid_ptr = min_distance_changed ? centroid : closest_centroid_ptr;
}
if (closest_centroid_ptr) {
result = true;
++closest_centroid_ptr->weight;
}
if (likely(distance != nullptr))
*distance = min_distance;
return result;
}
/*
* given a set of centroids (as a PG list) and a set of points, this function computes
* the cost of the set of centroids as well as their weights (number of points assigned
* to each centroid)
 * we assume that all weights have been reset already
*/
bool compute_cost_and_weights(List const * centroids, GSPoint const * points, uint32_t dimension,
uint32_t const num_slots, double *cost)
{
GSPoint const * point = nullptr;
double local_distance = 0.;
double cost_of_batch = 0.;
double local_op_error = 0.;
double total_op_error = 0.;
uint32_t current_slot = 0U;
bool const result = current_slot < num_slots;
// for every point we compute the closest centroid and increase the weight of the corresponding centroid
while (current_slot < num_slots) {
point = points + current_slot;
closest_centroid(centroids, point, dimension, &local_distance);
twoSum(cost_of_batch, local_distance, &cost_of_batch, &local_op_error);
total_op_error += local_op_error;
++current_slot;
}
cost_of_batch += total_op_error;
*cost = cost_of_batch;
return result;
}
/*
* this runs kmeans++ on a super-set of centroids to obtain the k centroids we want as seeding
*/
List *kmeanspp(KMeansStateDescription *description, List *centroids_candidates, uint32_t const idx_current_centroids,
uint32_t const size_centroid_bytes, std::mt19937_64 *prng)
{
Centroid *centroids = description->centroids[idx_current_centroids];
uint32_t const num_centroids_needed = description->num_centroids;
uint32_t num_candidates = centroids_candidates ? centroids_candidates->length : 0;
uint32_t const dimension = description->dimension;
ListCell *current_candidate_cell = nullptr;
ListCell *prev_candidate_cell = nullptr;
ListCell *tmp_cell = nullptr;
std::uniform_real_distribution<double> unit_sampler(0., 1.);
// if there are fewer candidates than centroids needed, then all of them become centroids
double sample_probability = num_candidates <= num_centroids_needed ? 1. : 1. / static_cast<double>(num_candidates);
double sample_probability_correction = 0.;
double candidate_probability = 0.;
double distance = 0.;
double sum_distances = 0.;
double sum_distances_local_correction = 0.;
double sum_distances_correction = 0.;
ArrayType *current_candidate_pgarray = nullptr;
ArrayType *current_centroid_pgarray = nullptr;
GSPoint *current_candidate = nullptr;
GSPoint *tmp_candidate = nullptr;
uint32_t current_centroid_idx = description->current_centroid;
uint32_t tries_until_next_centroid = 0;
bool no_more_candidates = false;
/*
* we expect to produce all centroids in one go and to be able to produce them because
* we have enough candidates
*/
if ((current_centroid_idx > 0) || (num_centroids_needed == 0))
ereport(ERROR,
(errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means: not able to run k-means++")));
/*
* we produce the very first centroid uniformly at random (not weighted)
* the rest of the centroids are found by sampling w.r.t. the (weighted) distance to their closest centroid
* (elements that are far from their centroids are more likely to become centroids themselves,
* but heavier points also have better chances)
*/
while (likely((current_centroid_idx < num_centroids_needed) && !no_more_candidates)) {
++tries_until_next_centroid;
/*
 * the probability of not choosing a centroid in a round is > 0
* so we loop until we have found all our k centroids
* in each round one centroid will be found at most
*/
prev_candidate_cell = nullptr;
no_more_candidates = true;
current_candidate_cell = centroids_candidates ? centroids_candidates->head : nullptr;
for (; current_candidate_cell != nullptr; current_candidate_cell = lnext(current_candidate_cell)) {
current_candidate = reinterpret_cast<GSPoint *>(lfirst(current_candidate_cell));
current_candidate_pgarray = current_candidate->pg_coordinates;
no_more_candidates = false;
candidate_probability = unit_sampler(*prng);
/*
* the very first candidate will be sampled uniformly, the rest will be sampled
* w.r.t. the distance to their closest centroid (the farther the more probable)
*/
if (likely(current_centroid_idx > 0)) {
// this distance is already weighted and thus requires no weighting again
distance = current_candidate->distance_to_closest_centroid;
twoDiv(distance, sum_distances, &sample_probability, &sample_probability_correction);
sample_probability += sample_probability_correction;
}
/*
* this is weighted sampling (taking into consideration the weight of the point)
*/
if (candidate_probability >= sample_probability) {
prev_candidate_cell = current_candidate_cell;
continue;
}
/*
* this candidate becomes a centroid
*/
current_centroid_pgarray = centroids[current_centroid_idx].coordinates;
/*
* we copy the coordinates of the centroid to the official set of centroids
*/
if (unlikely((current_centroid_pgarray != nullptr) and (current_candidate_pgarray != nullptr))) {
auto point_coordinates_to = reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid_pgarray));
auto point_coordinates_from = reinterpret_cast<double *>(ARR_DATA_PTR(current_candidate_pgarray));
memset_s(point_coordinates_to, size_centroid_bytes, 0, size_centroid_bytes);
errno_t rc = memcpy_s(point_coordinates_to,
size_centroid_bytes, point_coordinates_from, size_centroid_bytes);
securec_check(rc, "\0", "\0");
}
/*
* we delete the element that just became centroid from the list of candidates (along all its information)
*/
centroids_candidates = list_delete_cell(centroids_candidates, current_candidate_cell, prev_candidate_cell);
pfree(current_candidate->pg_coordinates);
current_candidate_cell =
prev_candidate_cell ? prev_candidate_cell : centroids_candidates ? centroids_candidates->head : nullptr;
/*
* we can reset sum_distances because it will be overwritten below anyway
*/
sum_distances = 0.;
/*
* we update the distance to the closest centroid depending on the presence of the new
* centroid
*/
tmp_cell = centroids_candidates ? centroids_candidates->head : nullptr;
for (; tmp_cell != nullptr; tmp_cell = lnext(tmp_cell)) {
tmp_candidate = reinterpret_cast<GSPoint *>(lfirst(tmp_cell));
distance = l2_squared(reinterpret_cast<double *>(ARR_DATA_PTR(tmp_candidate->pg_coordinates)),
reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid_pgarray)), dimension);
/*
* medium temporal locality for a prefetch for write
*/
prefetch(lnext(tmp_cell), 1, 2);
/*
* since we are dealing with weighted points, the overall sum of distances has
* to consider the weight of each point
*/
twoMult(distance, tmp_candidate->weight, &distance, &sum_distances_correction);
distance += sum_distances_correction;
/*
* we store the weighted distance at the point to save the multiplication later on
* when sampling
* only if the new centroid becomes the closest one to a candidate we update
*/
if ((current_centroid_idx == 0) || (distance < tmp_candidate->distance_to_closest_centroid))
tmp_candidate->distance_to_closest_centroid = distance;
/*
* every distance appears as many times as the weight of the point
*/
twoSum(sum_distances, tmp_candidate->distance_to_closest_centroid, &sum_distances,
&sum_distances_local_correction);
sum_distances_correction += sum_distances_local_correction;
}
/*
* high temporal locality for a prefetch for read
*/
if (likely(current_candidate_cell != nullptr))
prefetch(lnext(current_candidate_cell), 0, 3);
sum_distances += sum_distances_correction;
++current_centroid_idx;
tries_until_next_centroid = 0;
break;
}
}
/*
* we get rid of the linked list of candidates in case some candidates are left
*/
while (centroids_candidates) {
current_candidate_cell = centroids_candidates->head;
current_candidate = reinterpret_cast<GSPoint *>(lfirst(current_candidate_cell));
/*
* low temporal locality for a prefetch for read
*/
prefetch(lnext(current_candidate_cell), 0, 1);
/*
* this frees the list cell
*/
centroids_candidates = list_delete_cell(centroids_candidates, current_candidate_cell, nullptr);
/*
* this frees the PG array inside the GSPoint
*/
pfree(current_candidate->pg_coordinates);
/*
 * and finally, this frees the GSPoint itself
*/
pfree(current_candidate);
}
description->current_centroid = current_centroid_idx;
return centroids_candidates;
}
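The weighted sampling loop above draws each remaining candidate with probability proportional to its weighted squared distance to the closest chosen centroid (the "D^2" sampling of k-means++). A minimal standalone sketch of that step using the standard library, not the list-based implementation in kmeanspp(); `sample_next_centroid` is a hypothetical name:

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Draw the index of the next centroid with probability proportional to
// each candidate's weighted distance to its closest centroid so far.
size_t sample_next_centroid(const std::vector<double> &weighted_distances, std::mt19937_64 &prng)
{
    std::discrete_distribution<size_t> pick(weighted_distances.begin(), weighted_distances.end());
    return pick(prng);
}
```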
/*
 * this aggregates the value of the cost function over all centroids
*/
void compute_cost(KMeansStateDescription *description, uint32_t const idx_current_centroids)
{
Centroid *centroids = description->centroids[idx_current_centroids];
uint32_t const num_centroids = description->num_centroids;
IncrementalStatistics *output_statistics = description->solution_statistics + idx_current_centroids;
output_statistics->reset();
for (uint32_t c = 0; c < num_centroids; ++c)
*output_statistics += centroids[c].statistics;
}
/*
* every iteration there is a set of centroids (the next one) that is reset
*/
void reset_centroids(KMeansStateDescription *description, uint32_t const idx_centroids,
uint32_t const size_centroid_bytes)
{
uint32_t const num_centroids = description->num_centroids;
Centroid *centroids = description->centroids[idx_centroids];
Centroid *centroid = nullptr;
double *centroid_coordinates = nullptr;
for (uint32_t c = 0; c < num_centroids; ++c) {
centroid = centroids + c;
centroid->statistics.reset();
centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(centroid->coordinates));
memset_s(centroid_coordinates, size_centroid_bytes, 0, size_centroid_bytes);
}
}
/*
* given the running mean of a centroid and a new point, this adds the new point to the aggregate
* using a sum that provides higher precision (we could provide much higher precision at the cost
* of allocating yet another array to keep correction terms for every dimension
*/
force_inline void aggregate_point(double *centroid_aggregation, double const * new_point, uint32_t const dimension)
{
double local_correction = 0;
for (uint32_t d = 0; d < dimension; ++d) {
twoSum(centroid_aggregation[d], new_point[d], centroid_aggregation + d, &local_correction);
centroid_aggregation[d] += local_correction;
}
}
/*
* this produces the centroid by dividing the aggregate by the amount of points it got assigned
 * we assume that population > 0
*/
force_inline void finish_centroid(double *centroid_aggregation, uint32_t const dimension, double const population)
{
double local_correction = 0.;
for (uint32_t d = 0; d < dimension; ++d) {
twoDiv(centroid_aggregation[d], population, centroid_aggregation + d, &local_correction);
centroid_aggregation[d] += local_correction;
}
}
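aggregate_point() and finish_centroid() rely on twoSum()/twoDiv() from db4ai/fp_ops.h. Assuming these are the usual error-free transformations, two-sum can be sketched as follows (Knuth's algorithm: s + e equals a + b exactly, where s is the rounded sum and e the rounding error). `two_sum` here is a standalone stand-in for illustration, not the library function:

```cpp
// Knuth's two-sum: computes s = round(a + b) and the exact rounding
// error e, so that s + e reconstructs a + b without loss.
void two_sum(double a, double b, double *s, double *e)
{
    *s = a + b;
    double b_virtual = *s - a;
    double a_virtual = *s - b_virtual;
    double b_roundoff = b - b_virtual;
    double a_roundoff = a - a_virtual;
    *e = a_roundoff + b_roundoff;
}
```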
/*
* updates the minimum bounding box to contain the new given point
*/
force_inline void update_bbox(double * const bbox_min, double * const bbox_max, double const * point,
uint32_t const dimension)
{
uint32_t current_dimension = 0;
double min = 0.;
double max = 0.;
double p = 0.;
while (current_dimension < dimension) {
min = bbox_min[current_dimension];
max = bbox_max[current_dimension];
p = point[current_dimension];
/*
* we do cmovs instead of ifs (we could spare a comparison from time to time
* by using ifs, but we increase branch misprediction as well. thus we settle
* for the branchless option (more expensive than a hit, but cheaper than a miss)
*/
bbox_min[current_dimension] = p < min ? p : min;
bbox_max[current_dimension] = p > max ? p : max;
++current_dimension;
}
}
bool init_kmeans(KMeansStateDescription *description, double *bbox_min, double *bbox_max, GSPoint const * batch,
uint32_t const num_slots, uint32_t const size_centroid_bytes)
{
uint32_t const dimension = description->dimension;
bool const first_run = description->current_iteration == 0;
uint32_t current_slot = 0;
GSPoint const * current_point = nullptr;
double const * current_point_coordinates = nullptr;
if (unlikely(num_slots == 0))
return false;
/*
* the very first slot of the batch is a bit special. observe that we have got here
* after the point has passed validity checks and thus it is safe to access its information
* directly
*/
current_point = batch + current_slot;
current_point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(current_point->pg_coordinates));
/*
* in the very first run we set the coordinates of the bounding box as the ones
* of the very first point (we improve from there)
*/
if (unlikely(first_run)) {
/*
* no need to memset to zero because in the very first run they are freshly allocated
* with palloc0
*/
errno_t rc = memcpy_s(bbox_min, size_centroid_bytes, current_point_coordinates, size_centroid_bytes);
securec_check(rc, "\0", "\0");
rc = memcpy_s(bbox_max, size_centroid_bytes, current_point_coordinates, size_centroid_bytes);
securec_check(rc, "\0", "\0");
description->current_iteration = 1;
} else {
update_bbox(bbox_min, bbox_max, current_point_coordinates, dimension);
}
++description->num_good_points;
/*
* let's consider the rest of the batch
*/
while (likely(++current_slot < num_slots)) {
current_point = batch + current_slot;
current_point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(current_point->pg_coordinates));
update_bbox(bbox_min, bbox_max, current_point_coordinates, dimension);
++description->num_good_points;
}
/*
* done with the batch
*/
return true;
}
/*
 * we assume that all slots in the batch are non-null (guaranteed by the caller)
 * and that the next set of centroids has been reset prior to the very first call
*/
void update_centroids(KMeansStateDescription *description, GSPoint *slots, uint32_t const num_slots,
uint32_t const idx_current_centroids, uint32_t const idx_next_centroids)
{
uint32_t const dimension = description->dimension;
uint32_t current_slot = 0U;
uint32_t current_centroid = 0U;
uint32_t const num_centroids = description->num_centroids;
uint32_t closest_centroid = 0U;
GSPoint const * current_point = nullptr;
double const * current_point_coordinates = nullptr;
double *current_centroid_coordinates = nullptr;
double *next_centroid_coordinates = nullptr;
double dist = 0.;
auto min_dist = DBL_MAX;
Centroid *current_centroids = description->centroids[idx_current_centroids];
Centroid *next_centroids = description->centroids[idx_next_centroids];
Centroid *closest_centroid_ptr = nullptr;
Centroid *closest_centroid_next_ptr = nullptr;
bool min_dist_change = false;
IncrementalStatistics local_statistics;
/*
* just in case, but this should not happen as we control the parent call
*/
if (unlikely(num_slots == 0))
return;
do {
current_centroid = 0U;
min_dist = DBL_MAX;
/*
* we obtain the coordinates of the current point
*/
current_point = slots + current_slot;
/*
 * this loop obtains the distance of the current point to all centroids and keeps the closest one
*/
while (likely(current_centroid < num_centroids)) {
current_centroid_coordinates =
reinterpret_cast<double *>(ARR_DATA_PTR(current_centroids[current_centroid].coordinates));
current_point_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(current_point->pg_coordinates));
dist = description->distance(current_point_coordinates, current_centroid_coordinates, dimension);
min_dist_change = dist < min_dist;
closest_centroid = min_dist_change ? current_centroid : closest_centroid;
min_dist = min_dist_change ? dist : min_dist;
++current_centroid;
}
/*
* once the closest centroid has been detected we proceed with the aggregation and update
* of statistics
*/
local_statistics.setTotal(min_dist);
local_statistics.setMin(min_dist);
local_statistics.setMax(min_dist);
local_statistics.setPopulation(1ULL);
closest_centroid_ptr = current_centroids + closest_centroid;
closest_centroid_next_ptr = next_centroids + closest_centroid;
closest_centroid_ptr->statistics += local_statistics;
/*
* for the next iteration (if there is any) we have to obtain a new centroid, which will be
* the average of the points that we aggregate here
*/
next_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(closest_centroid_next_ptr->coordinates));
aggregate_point(next_centroid_coordinates, current_point_coordinates, dimension);
} while (likely(++current_slot < num_slots));
}
void merge_centroids(KMeansStateDescription *description, uint32_t const idx_current_centroids,
uint32_t const idx_next_centroids, uint32_t const size_centroid_bytes)
{
uint32_t const num_centroids = description->num_centroids;
uint32_t const dimension = description->dimension;
Centroid * const current_centroids = description->centroids[idx_current_centroids];
Centroid * const next_centroids = description->centroids[idx_next_centroids];
Centroid *current_centroid = nullptr;
Centroid *next_centroid = nullptr;
double *current_centroid_coordinates = nullptr;
double *next_centroid_coordinates = nullptr;
for (uint32_t c = 0; c < num_centroids; ++c) {
next_centroid = next_centroids + c;
current_centroid = current_centroids + c;
next_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(next_centroid->coordinates));
/*
* if the cluster of a centroid is empty we copy it from the previous iteration verbatim
 * so as not to lose it
*/
if (unlikely(current_centroid->statistics.getPopulation() == 0)) {
current_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid->coordinates));
/*
* observe that we do not have to memset to zero because current_centroid was reset before the run
* and since no point was assigned to it, it has remained reset
*/
errno_t rc = memcpy_s(next_centroid_coordinates, size_centroid_bytes, current_centroid_coordinates,
size_centroid_bytes);
securec_check(rc, "\0", "\0");
} else {
finish_centroid(next_centroid_coordinates, dimension, current_centroid->statistics.getPopulation());
}
}
}
bool finish_kmeans()
{
return true;
}
ModelPredictor kmeans_predict_prepare(Model const * model)
{
return reinterpret_cast<ModelPredictor>(const_cast<Model *>(model));
}
Datum kmeans_predict(ModelPredictor model, Datum *data, bool *nulls, Oid *types, int32_t nargs)
{
auto kmeans_model = reinterpret_cast<ModelKMeans *>(model);
/*
* sanity checks
*/
if (unlikely(nargs != 1))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means predict: only a single attribute containing the coordinates is accepted")));
if (unlikely(types[0] != FLOAT8ARRAYOID))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means predict: only double precision array of coordinates is accepted")));
if (unlikely(nulls[0]))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means predict: array of coordinates cannot be null")));
uint32_t const num_centroids = kmeans_model->actual_num_centroids;
uint32_t const dimension = kmeans_model->dimension;
double (*distance)(double const *, double const *, uint32_t const) = nullptr;
if (unlikely(num_centroids == 0))
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means predict: number of centroids must be positive")));
switch (kmeans_model->distance_function_id) {
case KMEANS_L1:
distance = l1;
break;
case KMEANS_L2:
distance = l2;
break;
case KMEANS_L2_SQUARED:
distance = l2_squared;
break;
case KMEANS_LINF:
distance = linf;
break;
default:
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("k-means predict: distance function id %u not recognized", kmeans_model->distance_function_id)));
}
WHCentroid *current_centroid = nullptr;
WHCentroid *closest_centroid = nullptr;
auto input_point_pg_array = DatumGetArrayTypeP(data[0]);
auto min_distance = DBL_MAX;
double local_distance = 0.;
double const * input_point_coordinates = nullptr;
int32_t closest_centroid_id = -1;
bool const valid_input = verify_pgarray(input_point_pg_array, dimension);
bool min_distance_changed = false;
if (unlikely(!valid_input))
return Int32GetDatum(closest_centroid_id);
input_point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(input_point_pg_array));
for (uint32_t c = 0; c < num_centroids; ++c) {
current_centroid = kmeans_model->centroids + c;
local_distance = distance(input_point_coordinates, current_centroid->coordinates, dimension);
min_distance_changed = local_distance < min_distance;
min_distance = min_distance_changed ? local_distance : min_distance;
closest_centroid = min_distance_changed ? current_centroid : closest_centroid;
}
closest_centroid_id = closest_centroid->id;
/*
* for the time being there is no other way to get to the computed distance other than by a log
*/
return Int32GetDatum(closest_centroid_id);
}
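The four metrics behind the distance_function_id switch are implemented in db4ai/distance_functions.h. A hedged standalone sketch of what they are assumed to compute (the `*_dist` names are hypothetical): L1 (Manhattan), L2 (Euclidean), squared L2, and L-infinity (Chebyshev):

```cpp
#include <cmath>
#include <cstdint>

double l1_dist(const double *a, const double *b, uint32_t n)
{
    double s = 0.;
    for (uint32_t i = 0; i < n; ++i)
        s += std::fabs(a[i] - b[i]);  // sum of absolute differences
    return s;
}

double l2_squared_dist(const double *a, const double *b, uint32_t n)
{
    double s = 0.;
    for (uint32_t i = 0; i < n; ++i)
        s += (a[i] - b[i]) * (a[i] - b[i]);  // no sqrt: preserves ordering
    return s;
}

double l2_dist(const double *a, const double *b, uint32_t n)
{
    return std::sqrt(l2_squared_dist(a, b, n));
}

double linf_dist(const double *a, const double *b, uint32_t n)
{
    double m = 0.;
    for (uint32_t i = 0; i < n; ++i) {
        double d = std::fabs(a[i] - b[i]);
        m = d > m ? d : m;  // maximum coordinate-wise difference
    }
    return m;
}
```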


@ -0,0 +1,318 @@
# *DB4AI Snapshots* for relational dataset versioning
## 1. Introduction
The module *DB4AI Snapshots* provides a robust and efficient framework that allows any database user to manage multiple versions of relational datasets through a convenient API.
### 1.1 Compact storage with efficient access
Its main benefits are automated management of dependencies among large, versioned datasets and their compact storage. To this end, *DB4AI Snapshots* leverages redundancies among different versions of data to provide a high level of compaction while minimizing adverse effects on performance. At the same time, snapshot data remains efficiently accessible using the full expressiveness of standard SQL.
### 1.2 Immutable snapshot data
The snapshot API prevents the user from changing versioned data. Similar to a code repository, every change of data will generate a new version, i.e. a new snapshot.
### 1.3 Performance
In addition to compact data representation, the primary goal of *DB4AI Snapshots* is high read performance, e.g. when snapshots serve as training datasets in highly repetitive and concurrent read operations during concurrent training of multiple ML models.
As a secondary performance goal, *DB4AI Snapshots* provides efficient data manipulation for creating and manipulating huge volumes of versioned relational data.
### 1.4 Documentation and Automation
*DB4AI Snapshots* also maintains full documentation of the origin of any dataset, providing lineage and provenance for data, e.g. to be used in the context of reproducible training of ML models. Moreover, it facilitates automation, e.g. when applying repetitive transformation steps in data cleansing, or when automatically updating an existing training dataset with new data.
## 2. Quick start
### 2.1 Setup
*DB4AI Snapshots* is automatically installed in every new database instance of openGauss.
To this end, the CREATE DATABASE procedure creates the database schema *db4ai* within the new database and populates it with the objects required for managing snapshot data.
After successful database creation, any user may start exploring the snapshot functionality. No additional privileges are required.
### 2.2 Examples
Set snapshot mode to compact storage model *CSS* (**C**omputed **S**nap**S**hot mode).
```sql
SET db4ai_snapshot_mode = CSS;
```
Create a snapshot 'm0@1.0.0' from existing data stored in table 'test_data'. The SQL statement may use arbitrary joins and mappings for defining the snapshot.
CREATE SNAPSHOT m0 AS
SELECT a1, a3, a6, a8, a2, pk, a1 b9, a7, a5 FROM test_data;
Create snapshot 'm0@2.0.0' from snapshot 'm0@1.0.0' by applying arbitrary DML and DDL statements.
The new version number indicates a snapshot schema revision, caused by at least one ALTER statement.
CREATE SNAPSHOT m0 FROM @1.0.0 USING (
UPDATE SNAPSHOT SET a1 = 5 WHERE pk % 10 = 0;
ALTER SNAPSHOT ADD " u+ " INTEGER, ADD "x <> y"INT DEFAULT 2, ADD t CHAR(10) DEFAULT '',
DROP a2, DROP COLUMN IF EXISTS b9, DROP COLUMN IF EXISTS b10;
UPDATE SNAPSHOT SET "x <> y" = 8 WHERE pk < 100;
ALTER SNAPSHOT DROP " u+ ", DROP IF EXISTS " x+ ";
DELETE FROM SNAPSHOT WHERE pk = 3
);
Create snapshot 'm0@2.0.1' from snapshot 'm0@2.0.0' by UPDATE while using a reference to another table. The new version number indicates an update operation (minor data patch). This example uses an AS clause for introducing 'i' as the snapshot's custom correlation name for joining with tables during the UPDATE operation.
CREATE SNAPSHOT m0 FROM @2.0.0 USING (
UPDATE SNAPSHOT AS i SET a5 = o.a2 FROM test_data o
WHERE i.pk = o.pk AND o.a3 % 8 = 0
);
Create snapshot 'm0@2.1.0' from snapshot 'm0@2.0.1' by DELETE while using a reference to another table. The new version number indicates a data revision. This example uses the snapshot's default correlation name 'SNAPSHOT' for joining with another table.
CREATE SNAPSHOT m0 FROM @2.0.1 USING (
DELETE FROM SNAPSHOT USING test_data o
WHERE SNAPSHOT.pk = o.pk AND o.A7 % 2 = 0
);
Create snapshot 'm0@2.2.0' from snapshot 'm0@2.1.0' by inserting new data. The new version number indicates another data revision.
CREATE SNAPSHOT m0 FROM @2.1.0 USING (
INSERT INTO SNAPSHOT SELECT a1, a3, a6, a8, a2, pk+1000 pk, a7, a5, a4
FROM test_data WHERE pk % 10 = 4
);
The SQL syntax was extended with the new @ operator in relation names, allowing the user to specify a snapshot with its version throughout SQL. Internally, snapshots are stored as views, whose actual names are generated according to GUC parameters that may be set at an arbitrary scope, e.g. at database level using the current setting of the version delimiter:
-- DEFAULT db4ai_snapshot_version_delimiter IS '@'
ALTER DATABASE <name> SET db4ai_snapshot_version_delimiter = '#';
Similarly, the version separator can be changed by the user:
-- DEFAULT db4ai_snapshot_version_separator IS '.'
ALTER DATABASE <name> SET db4ai_snapshot_version_separator = '_';
Independently from the GUC parameter settings mentioned above, any snapshot version can be accessed:
-- standard version string @schema.revision.patch:
SELECT * FROM public.data@1.2.3;
-- user-defined version strings:
SELECT * FROM accounting.invoice@2021;
SELECT * FROM user.cars@cleaned;
-- quoted identifier for blanks, keywords, special characters, etc.:
SELECT * FROM user.cars@"rev 1.1";
-- or string literal:
SELECT * FROM user.cars@'rev 1.1';
Alternatively, the internal view name can be used (depends on GUC settings):
-- With internal view name, using default GUC settings
SELECT * FROM public."data@1.2.3";
-- With internal view name, using custom GUC settings, as above
SELECT * FROM public.data#1_2_3;
## 3. Privileges
All members of role **PUBLIC** may use **DB4AI Snapshots**.
## 4. Supported Systems and SQL compatibility modes
The current version of *DB4AI Snapshots* is tested in openGauss SQL compatibility modes A, B and C.
## 5. Portability
*DB4AI Snapshots* uses standard SQL for implementing its functionality.
## 6. Dependencies
None.
## 7. Reference & Documentation
### 7.1 Configuration parameters
**DB4AI Snapshots** exposes several configuration parameters, via the system's global unified configuration (GUC) management.
Configuration parameters may be set at the scope of functions (CREATE FUNCTION), transactions (SET LOCAL), sessions (SET),
users (ALTER USER), databases (ALTER DATABASE), or system-wide (postgresql.conf).
SET [SESSION | LOCAL] configuration_parameter { TO | = } { value | 'value' }
CREATE FUNCTION <..> SET configuration_parameter { TO | = } { value | 'value' }
ALTER DATABASE name SET configuration_parameter { TO | = } { value | 'value' }
ALTER USER name [ IN DATABASE database_name ] SET configuration_parameter { TO | = } { value | 'value' }
The following snapshot configuration parameters are currently supported:
#### db4ai_snapshot_mode = { MSS | CSS }
This snapshot configuration parameter allows switching between materialized snapshot (*MSS*) mode, where every new snapshot is created as a compressed but fully
materialized copy of its parent's data, and computed snapshot (*CSS*) mode. In *CSS* mode, the system attempts to exploit redundancies among dependent snapshot versions
to minimize storage requirements.
The setting of *db4ai_snapshot_mode* may be adjusted at any time and takes effect on subsequent snapshot operations within the scope of the new setting.
Whenever *db4ai_snapshot_mode* is not set in the current scope, it defaults to *MSS*.
##### Example
SET db4ai_snapshot_mode = CSS;
#### db4ai_snapshot_version_delimiter = value
This snapshot configuration parameter controls the character that delimits the *snapshot version* postfix within snapshot names.
In consequence, the character used as *db4ai_snapshot_version_delimiter* cannot appear in snapshot names, neither in the snapshot
name prefix nor in the snapshot version postfix. Also, the setting of *db4ai_snapshot_version_delimiter* must be distinct from
*db4ai_snapshot_version_separator*. Whenever *db4ai_snapshot_version_delimiter* is not set in the current scope, it defaults to
the symbol '@' (at sign).
*Note:* Snapshots created with different settings of *db4ai_snapshot_version_delimiter* are not compatible with each
other. Hence, it is advisable to keep the setting stable, e.g. by setting it permanently at database scope.
##### Example
ALTER DATABASE name SET db4ai_snapshot_version_delimiter = '#';
#### db4ai_snapshot_version_separator = value
This snapshot configuration parameter controls the character that separates the *snapshot version* within snapshot names. In consequence,
*db4ai_snapshot_version_separator* must not be set to any character representing a digit [0-9].
Also, the setting of *db4ai_snapshot_version_separator* must be distinct from *db4ai_snapshot_version_delimiter*.
Whenever *db4ai_snapshot_version_separator* is not set in the current scope, it defaults to punctuation mark '.' (period).
*Note:* Snapshots created with different settings of *db4ai_snapshot_version_separator* do not support automatic version number
generation across each other. Hence, it is advisable to keep the setting stable, e.g. by setting it permanently at database scope.
##### Example
ALTER DATABASE name SET db4ai_snapshot_version_separator = '_';
### 7.2 Accessing a snapshot
Independently from the GUC parameter settings mentioned above, any snapshot version can be accessed:
SELECT … FROM […,]
<snapshot_qualified_name> @ <vconst | ident | sconst>
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
Alternatively, using the standard version string as internal name (depends on GUC settings):
SELECT … FROM […,] <snapshot_qualified_name> INTEGER
<db4ai_snapshot_version_delimiter> INTEGER
<db4ai_snapshot_version_delimiter> INTEGER
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
Alternatively, using a user-defined version string as internal name (depends on GUC settings):
SELECT … FROM […,]
<snapshot_qualified_name> <db4ai_snapshot_version_delimiter> <snapshot_version_string>
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
If any component of <snapshot_qualified_name>, <db4ai_snapshot_version_delimiter>, <db4ai_snapshot_version_separator>, or <snapshot_version_string> contains special characters, quoting of the snapshot name is required.
### 7.3 Creating a snapshot
CREATE SNAPSHOT <qualified_name> [@ <version | ident | sconst>]
[COMMENT IS <sconst>]
AS <SelectStmt>;
The CREATE SNAPSHOT AS statement is invoked for creating a new snapshot. The caller provides the qualified_name for the snapshot to be created. The \<SelectStmt\> defines the content of the new snapshot in SQL. The optional @ operator allows assigning a custom version number or string to the new snapshot. The default version number is @1.0.0.
A snapshot may be annotated using the optional COMMENT IS clause.
**Example:**
CREATE SNAPSHOT public.cars AS
SELECT id, make, price, modified FROM cars_table;
The CREATE SNAPSHOT AS statement will create the snapshot 'public.cars@1.0.0' by selecting some columns of all the tuples in relation cars_table, which exists in the operational data store.
The created snapshot's name 'public.cars' is automatically extended with the suffix '@1.0.0' to the full snapshot name 'public.cars@1.0.0', thereby creating a unique, versioned identifier for the snapshot.
The DB4AI module of openGauss stores metadata associated with snapshots in a DB4AI catalog table *db4ai.snapshot*. The catalog exposes various metadata about the snapshot, particularly noteworthy is the field 'snapshot_definition' that provides documentation how the snapshot was generated. The DB4AI catalog serves for managing the life cycle of snapshots and allows exploring available snapshots in the system.
In summary, an invocation of the CREATE SNAPSHOT AS statement creates a corresponding entry in the DB4AI catalog, with a unique snapshot name and documentation of the snapshot's lineage. The new snapshot is in state 'published'. Initial snapshots serve as true and reusable copies of operational data and as a starting point for subsequent data curation; therefore initial snapshots are already immutable. In addition, the system creates a view with the published snapshot's name, with grantable read-only privileges for the current user. The current user may access the snapshot using arbitrary SQL statements against this view, or grant read-access privileges to other users to share the new snapshot for collaboration. Published snapshots may be used for model training, by passing the new snapshot name as an input parameter to the *db4ai.train* function of the DB4AI model warehouse. Other users may discover new snapshots by browsing the DB4AI catalog and, if corresponding read-access privileges on the snapshot view are granted by the snapshot's creator, collaborative model training using this snapshot as training data can commence.
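For illustration, discovering snapshots and sharing one might look like the following sketch. The catalog columns 'schema', 'name', 'owner', and 'published' follow the *db4ai.snapshot* definition; the role name 'colleague' and the exact view name, which depends on the GUC settings described above, are assumptions:

    -- browse the DB4AI catalog for available snapshots
    SELECT schema, name, owner, published FROM db4ai.snapshot;
    -- grant read access on the snapshot view to another user
    GRANT SELECT ON public."cars@1.0.0" TO colleague;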
### 7.4 Creating a snapshot revision
CREATE SNAPSHOT <qualified_name> [@ <version | ident | sconst>]
FROM @ <version | ident | sconst>
[COMMENT IS <sconst>]
USING (
{ INSERT [INTO SNAPSHOT] …
| UPDATE [SNAPSHOT] [AS <alias>] SET … [FROM …] [WHERE …]
| DELETE [FROM SNAPSHOT] [AS <alias>] [USING …] [WHERE …]
| ALTER [SNAPSHOT] { ADD … | DROP … } [, …]
} [; …]
);
The CREATE SNAPSHOT FROM statement serves for creating a modified and immutable snapshot based on an existing snapshot. The parent snapshot is specified by the qualified_name and the version provided in the FROM clause.
The new snapshot is created within the parent's schema and it also inherits the prefix of the parent's name, but without the parent's version number. The statements listed in the USING clause define how the parent snapshot shall be modified by means of a batch of SQL DDL and DML statements, i.e. ALTER, INSERT, UPDATE, and DELETE.
**Examples:**
CREATE SNAPSHOT public.cars FROM @1.0.0 USING (
ALTER ADD year int, DROP make;
INSERT SELECT * FROM cars_table WHERE modified=CURRENT_DATE;
UPDATE SET year=in.year FROM cars_table in WHERE SNAPSHOT.id=in.id;
DELETE WHERE modified<CURRENT_DATE-30
); -- Example with 'short SQL' notation
CREATE SNAPSHOT public.cars FROM @1.0.0 USING (
ALTER SNAPSHOT ADD COLUMN year int, DROP COLUMN make;
INSERT INTO SNAPSHOT SELECT * FROM cars_table WHERE modified=CURRENT_DATE;
UPDATE SNAPSHOT SET year=in.year FROM cars_table in WHERE SNAPSHOT.id=in.id;
DELETE FROM SNAPSHOT WHERE modified<CURRENT_DATE-30
); -- Example with standard SQL syntax
In this example, the new snapshot version starts with the current state of snapshot 'public.cars@1.0.0' and adds a new column 'year' to the 'cars' snapshot, while dropping column 'make' that has become irrelevant for the user. This first example uses the short SQL notation, where the individual statements are provided by the user without explicitly stating the snapshot's correlation name. In addition to this syntax, openGauss also accepts standard SQL statements (second example) which tend to be slightly more verbose. Note that both variants allow the introduction of custom correlation names in UPDATE FROM and DELETE USING statements with the AS clause, e.g. `UPDATE AS s [...] WHERE s.id=in.id);` or `UPDATE SNAPSHOT AS s [...] WHERE s.id=in.id);` in the example above.
The INSERT operation shows an example for pulling fresh data from the operational data store into the new snapshot. The UPDATE operation exemplifies populating the newly added column 'year' with data coming from the operational data store, and finally the DELETE operation demonstrates how to remove obsolete data from a snapshot. The name of the resulting snapshot of this invocation is 'public.cars@2.0.0'. As with CREATE SNAPSHOT AS, the user may override the default version numbering scheme that generates version number '@2.0.0'. The optional COMMENT IS clause allows the user to associate a descriptive textual 'comment' with the unit of work corresponding to this invocation of CREATE SNAPSHOT FROM, for improving collaboration and documentation, as well as for change tracking purposes.
Since all snapshots are immutable by definition, CREATE SNAPSHOT FROM creates a separate snapshot 'public.cars@2.0.0', initially as a logical copy of the parent snapshot 'public.cars@1.0.0', and applies changes from the USING clause, which corresponds to an integral unit of work in data curation. Similar to a SQL script, the batch of operations is executed atomically and consecutively on the logical copy of the parent snapshot. The parent snapshot remains immutable and is completely unaffected by these changes.
Additionally, the system automatically records all applied changes in the DB4AI snapshot catalog, such that the catalog contains an accurate, complete, and consistent documentation of all changes applied to any snapshot. The catalog also stores the origin and lineage of the created snapshot as a reference to its parent, making provenance of data fully traceable and, as demonstrated later, serves as a repository of data curation operations for supporting automation in snapshot generation.
The operations themselves allow data scientists to remove columns from the parent snapshot, but also to add and populate new ones, e.g. for the purpose of data annotation. Via INSERT, rows may be freely added, e.g. from the operational data source or from other snapshots. Inaccurate or irrelevant data can be deleted as part of the data cleansing process, regardless of whether the data comes from an immutable parent snapshot or directly from the operational data store. Finally, UPDATE statements allow correction of inaccurate or corrupt data, serve for imputation of missing values, and allow normalization of numeric values to a common scale.
In summary, the CREATE SNAPSHOT FROM statement was designed for supporting the full spectrum of recurring tasks in data curation:
• Data cleansing: remove or correct irrelevant, inaccurate, or corrupt data
• Data imputation: fill in missing data
• Labeling & annotation: add immutable columns with computed values
• Data normalization: update existing columns to a common scale
• Permutation: support reordering of data for iterative model training
• Indexing: support random access for model training
Invoking the CREATE SNAPSHOT FROM statement allows multiple users to collaborate concurrently in the process of data curation, where each user may break data curation tasks into a set of CREATE SNAPSHOT FROM operations, to be executed in atomic batches. This form of collaboration is similar to software engineers collaborating on a common code repository, but here the concept is extended to cover both code and data. One invocation of CREATE SNAPSHOT FROM corresponds to a commit operation in a git repository.
In summary, an invocation of the CREATE SNAPSHOT FROM statement creates a corresponding entry in the DB4AI catalog, with a unique snapshot name and documentation of the snapshot's lineage. The new snapshot remains in state 'unpublished', potentially awaiting further data curation. In addition, the system creates a view with the created snapshot's name, with grantable read-only privileges for the current user. Concurrent calls to CREATE SNAPSHOT FROM are permissive and result in separate new versions originating from the same parent (branches). The current user may access the snapshot using arbitrary, read-only SQL statements against this view, or grant read-access privileges to other users to share the created snapshot and enable collaboration in data curation. Unpublished snapshots may not participate in model training. Yet, other users may discover unpublished snapshots by browsing the DB4AI catalog and, if corresponding read-access privileges on the snapshot view are granted by the snapshot's creator, collaborative data curation using this snapshot can commence.
### 7.5 Sampling snapshots
SAMPLE SNAPSHOT <qualified_name> @ <version | ident | sconst>
[STRATIFY BY attr_list]
{ AS <label> AT RATIO <num> [COMMENT IS <comment>] } [, …]
The SAMPLE SNAPSHOT statement samples data from a given snapshot (the original snapshot) into one or more descendant, but independent, snapshots (branches), according to the ratio given via the AT RATIO clause.
**Example:**
SAMPLE SNAPSHOT public.cars@2.0.0
STRATIFY BY color
AS _train AT RATIO .8,
AS _test AT RATIO .2;
This invocation of SAMPLE SNAPSHOT creates two snapshots from the snapshot 'cars@2.0.0', one designated for ML model training purposes: 'cars_train@2.0.0' and the other for ML model testing: 'cars_test@2.0.0'. Note that descendant snapshots inherit the parent's schema, name prefix and version suffix, while each sample definition provides a name infix for making descendant snapshot names unique. The AT RATIO clause specifies the ratio of tuples qualifying for the resulting snapshots, namely 80% for training and 20% for testing. The STRATIFY BY clause specifies that the fraction of records for each car color (white, black, red…) is the same in all three participating snapshots.
### 7.6 Publishing snapshots
PUBLISH SNAPSHOT <qualified_name> @ <version | ident | sconst>;
Whenever a snapshot is created with the CREATE SNAPSHOT FROM statement, it is initially unavailable for ML model training. Such snapshots allow users to collaboratively apply further changes in manageable units of work, for facilitating cooperation in data curation. A snapshot is finalized by publishing it via the PUBLISH SNAPSHOT statement. Published snapshots may be used for model training, by using the new snapshot name as input parameter to the *db4ai.train* function of the DB4AI model warehouse. Other users may discover new snapshots by browsing the DB4AI catalog, and if corresponding read access privileges on the snapshot view are granted by the snapshot's creator, collaborative model training using this snapshot as training data can commence.
**Example:**
PUBLISH SNAPSHOT public.cars@2.0.0;
Above is an exemplary invocation, publishing snapshot 'public.cars@2.0.0'.
### 7.7 Archiving snapshots
ARCHIVE SNAPSHOT <qualified_name> @ <version | ident | sconst>;
Archiving changes the state of any snapshot to 'archived'. While archived, the snapshot remains immutable and cannot participate in CREATE SNAPSHOT FROM or *db4ai.train* operations. Archived snapshots may be purged, permanently deleting their data and recovering occupied storage space, or they may be reactivated by invoking PUBLISH SNAPSHOT on the archived snapshot.
**Example:**
ARCHIVE SNAPSHOT public.cars@2.0.0;
The example above archives snapshot 'public.cars@2.0.0' that was previously in state 'published' or 'unpublished'.
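Since archived snapshots can be reactivated with PUBLISH SNAPSHOT (see section 7.6), a typical archive-and-restore cycle might look like this sketch:

    -- temporarily retire the snapshot
    ARCHIVE SNAPSHOT public.cars@2.0.0;
    -- later, reactivate it for further use in model training
    PUBLISH SNAPSHOT public.cars@2.0.0;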
### 7.8 Purging snapshots
PURGE SNAPSHOT <qualified_name> @ <version | ident | sconst>;
The PURGE SNAPSHOT statement is used to permanently delete all data associated with a snapshot from the system. A prerequisite to purging is that the snapshot is not referenced by any existing trained model in the DB4AI model warehouse; snapshots still referenced by trained models cannot be purged.
Purging snapshots without existing descendant snapshots removes them completely, and the occupied storage space is recovered. If descendant snapshots exist, the purged snapshot is merged into adjacent snapshots, such that no information on lineage is lost, but storage efficiency is improved. In any case, the purged snapshot's name becomes invalid and is removed from the system.
**Example:**
PURGE SNAPSHOT public.cars@2.0.0;
The example above recovers storage space occupied by 'public.cars@2.0.0' by removing the snapshot completely.
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* create.sql
* Create DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/create.sql
*
* -------------------------------------------------------------------------
*/
CREATE OR REPLACE FUNCTION db4ai.create_snapshot_internal(
IN s_id BIGINT, -- snapshot id
IN i_schema NAME, -- snapshot namespace
IN i_name NAME, -- snapshot name
IN i_commands TEXT[], -- commands defining snapshot data and layout
IN i_comment TEXT, -- snapshot description
IN i_owner NAME -- snapshot owner
)
RETURNS VOID LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
e_stack_act TEXT; -- current stack for validation
dist_cmd TEXT; -- DISTRIBUTE BY translation for backing table
row_count BIGINT; -- number of rows in this snapshot
BEGIN
BEGIN
RAISE EXCEPTION 'SECURITY_STACK_CHECK';
EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;
IF CURRENT_SCHEMA = 'db4ai' THEN
e_stack_act := replace(e_stack_act, 'ion cre', 'ion db4ai.cre');
END IF;
IF e_stack_act NOT LIKE E'referenced column: create_snapshot_internal\n'
'SQL statement "SELECT db4ai.create_snapshot_internal(s_id, i_schema, i_name, i_commands, i_comment, CURRENT_USER)"\n'
'PL/pgSQL function db4ai.create_snapshot(name,name,text[],name,text) line 269 at PERFORM%'
THEN
RAISE EXCEPTION 'direct call to db4ai.create_snapshot_internal(bigint,name,name,text[],text,name) is not allowed'
USING HINT = 'call public interface db4ai.create_snapshot instead';
END IF;
END;
IF length(i_commands[3]) > 0 THEN
<<translate_dist_by_hash>>
DECLARE
pattern TEXT; -- current user column name
mapping NAME[]; -- mapping user column names to internal backing columns
quoted BOOLEAN := FALSE; -- inside quoted identifier
cur_ch VARCHAR; -- current character in tokenizer
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
tokens TEXT; -- user's column name list in DISTRIBUTE BY HASH()
BEGIN
-- extract mapping from projection list for view definition
mapping := array(SELECT unnest(ARRAY[ m[1], coalesce(m[2], replace(m[3],'""','"'))]) FROM regexp_matches(
i_commands[5], 't[0-9]+\.(f[0-9]+) AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")', 'g') m);
-- extract field list from DISTRIBUTE BY clause
tokens :=(regexp_matches(i_commands[3], '^\s*DISTRIBUTE\s+BY\s+HASH\s*\((.*)\)\s*$', 'i'))[1];
IF tokens IS NULL OR tokens SIMILAR TO '\s*' THEN
tokens := (regexp_matches(i_commands[3], '^\s*DISTRIBUTE\s+BY\s+REPLICATION\s*$', 'i'))[1];
IF tokens IS NULL OR tokens SIMILAR TO '\s*' THEN
RAISE EXCEPTION 'cannot match DISTRIBUTE BY clause'
USING HINT = 'currently only DISTRIBUTE BY REPLICATION and DISTRIBUTE BY HASH(column_name [, ...]) supported';
END IF;
-- no translation required, bail out
dist_cmd := ' ' || i_commands[3];
EXIT translate_dist_by_hash;
END IF;
tokens := tokens || ' ';
-- prepare the translated command
dist_cmd = ' DISTRIBUTE BY HASH(';
-- BEGIN tokenizer code for testing
pattern := '';
LOOP
idx := idx + 1;
cur_ch := substr(tokens, idx, 1);
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
CASE cur_ch
WHEN '"' THEN
IF quoted AND substr(tokens, idx + 1, 1) = '"' THEN
pattern := pattern || '"';
idx := idx + 1;
ELSE
quoted := NOT quoted;
END IF;
IF quoted THEN
CONTINUE;
END IF;
WHEN ',' THEN
IF quoted THEN
pattern := pattern || cur_ch;
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
pattern := ',';
ELSE
idx := idx - 1; -- reset on comma for next loop
END IF;
WHEN ' ', E'\n', E'\t' THEN
IF quoted THEN
pattern := pattern || cur_ch;
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
CONTINUE;
END IF;
ELSE
pattern := pattern || CASE WHEN quoted THEN cur_ch ELSE lower(cur_ch) END;
CONTINUE;
END CASE;
-- END tokenizer code for testing
-- attempt to map the pattern
FOR idx IN 2 .. array_length(mapping, 1) BY 2 LOOP
IF pattern = mapping[idx] THEN
-- apply the mapping
dist_cmd := dist_cmd || mapping[idx-1] || ',';
pattern := NULL;
EXIT;
END IF;
END LOOP;
-- check if pattern was mapped
IF pattern IS NOT NULL THEN
RAISE EXCEPTION 'unable to map field "%" to backing table', pattern;
END IF;
END LOOP;
IF quoted THEN
RAISE EXCEPTION 'unterminated quoted identifier ''%'' at or near: ''%''',
substr(pattern, 1, char_length(pattern)-1), i_commands[3];
END IF;
dist_cmd := rtrim(dist_cmd, ',') || ')';
END;
END IF;
dist_cmd := ''; -- we silently drop DISTRIBUTE_BY
EXECUTE 'CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)' || dist_cmd
|| ' AS SELECT ' || i_commands[4] || ' FROM _db4ai_tmp_x' || s_id;
EXECUTE 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || quote_ident(i_schema)
|| '.' || quote_ident(i_name) || '''';
EXECUTE 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || i_commands[5] || ', xc_node_id, ctid FROM db4ai.t' || s_id;
EXECUTE 'COMMENT ON VIEW db4ai.v' || s_id || ' IS ''snapshot ' || quote_ident(i_schema) || '.' || quote_ident(i_name)
|| ' backed by db4ai.t' || s_id || CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
EXECUTE 'GRANT SELECT ON db4ai.v' || s_id || ' TO ' || i_owner || ' WITH GRANT OPTION';
EXECUTE 'SELECT COUNT(*) FROM db4ai.v' || s_id INTO STRICT row_count;
-- store only original commands supplied by user
i_commands := ARRAY[i_commands[1], i_commands[2], i_commands[3]];
INSERT INTO db4ai.snapshot (id, root_id, schema, name, owner, commands, comment, published, row_count)
VALUES (s_id, s_id, i_schema, i_name, i_owner, i_commands, i_comment, TRUE, row_count);
END;
$$;
CREATE OR REPLACE FUNCTION db4ai.create_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
IN i_name NAME, -- snapshot name
IN i_commands TEXT[], -- commands defining snapshot data and layout
IN i_vers NAME DEFAULT NULL, -- override version postfix
IN i_comment TEXT DEFAULT NULL -- snapshot description
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
s_id BIGINT; -- snapshot id
s_mode VARCHAR(3); -- current snapshot mode
s_vers_del CHAR; -- snapshot version delimiter, default '@'
s_vers_sep CHAR; -- snapshot version separator, default '.'
separation_of_powers TEXT; --current separation of rights
qual_name TEXT; -- qualified snapshot name
command_str TEXT; -- command string
pattern TEXT; -- command pattern for matching
proj_cmd TEXT; -- SELECT clause for create user view command (may be NULL if from_cmd is not NULL)
from_cmd TEXT; -- FROM clause for command (may be NULL if proj_cmd is not NULL)
dist_cmd TEXT; -- DISTRIBUTE BY clause for command (may be NULL)
res db4ai.snapshot_name; -- composite result
BEGIN
-- obtain active message level
BEGIN
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
EXCEPTION WHEN OTHERS THEN
END;
-- obtain database state of separation of rights
BEGIN
separation_of_powers := upper(current_setting('enableSeparationOfDuty'));
EXCEPTION WHEN OTHERS THEN
separation_of_powers := 'OFF';
END;
IF separation_of_powers NOT IN ('ON', 'OFF') THEN
RAISE EXCEPTION 'Uncertain state of separation of rights.';
ELSIF separation_of_powers = 'ON' THEN
RAISE EXCEPTION 'Snapshot is not supported in separation of rights';
END IF;
-- obtain active snapshot mode
BEGIN
s_mode := upper(current_setting('db4ai_snapshot_mode'));
EXCEPTION WHEN OTHERS THEN
s_mode := 'MSS';
END;
IF s_mode NOT IN ('CSS', 'MSS') THEN
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
END IF;
-- obtain relevant configuration parameters
BEGIN
s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
EXCEPTION WHEN OTHERS THEN
s_vers_del := '@';
END;
BEGIN
s_vers_sep := current_setting('db4ai_snapshot_version_separator');
EXCEPTION WHEN OTHERS THEN
s_vers_sep := '.';
END;
-- check all input parameters
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
IF i_name IS NULL OR i_name = '' THEN
RAISE EXCEPTION 'i_name cannot be NULL or empty';
ELSIF strpos(i_name, s_vers_del) > 0 THEN
RAISE EXCEPTION 'i_name must not contain ''%'' characters', s_vers_del;
END IF;
-- PG BUG: array_ndims('{}') or array_dims(ARRAY[]::INT[]) returns NULL
IF i_commands IS NULL OR array_length(i_commands, 1) IS NULL OR array_length(i_commands, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_commands array malformed'
USING HINT = 'pass SQL commands as TEXT[] literal, e.g. ''{SELECT *, FROM public.t, DISTRIBUTE BY HASH(id)}''';
END IF;
FOREACH command_str IN ARRAY i_commands LOOP
IF command_str IS NULL THEN
RAISE EXCEPTION 'i_commands array contains NULL values';
END IF;
END LOOP;
FOREACH command_str IN ARRAY i_commands LOOP
command_str := btrim(command_str);
pattern := upper(regexp_replace(left(command_str, 30), '\s+', ' ', 'g'));
IF left(pattern, 7) = 'SELECT ' THEN
IF proj_cmd IS NULL THEN
proj_cmd := command_str;
DECLARE
nested INT := 0; -- level of nesting
quoted BOOLEAN := FALSE; -- inside quoted identifier
cur_ch VARCHAR; -- current character in tokenizer
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
start INTEGER := 1;
stmt TEXT := command_str;
BEGIN
-- BEGIN splitter code for testing
pattern := '';
LOOP
idx := idx + 1;
cur_ch := substr(stmt, idx, 1);
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
CASE cur_ch
WHEN '"' THEN
IF quoted AND substr(stmt, idx + 1, 1) = '"' THEN
idx := idx + 1;
ELSE
quoted := NOT quoted;
END IF;
IF quoted THEN
CONTINUE;
END IF;
WHEN '(' THEN
nested := nested + 1;
CONTINUE;
WHEN ')' THEN
nested := nested - 1;
IF nested < 0 THEN
RAISE EXCEPTION 'syntax error at or near '')'' in ''%'' at position ''%''', stmt, idx;
END IF;
CONTINUE;
WHEN ' ' THEN
IF quoted OR nested > 0 THEN
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
start := idx;
CONTINUE;
END IF;
WHEN ';' THEN
RAISE EXCEPTION 'syntax error at or near '';'' in ''%'' at position ''%''', stmt, idx;
CONTINUE;
ELSE
pattern := pattern || upper(cur_ch);
CONTINUE;
END CASE;
-- END splitter code for testing
IF pattern = 'FROM' THEN
from_cmd := substr(stmt, start + 1);
proj_cmd := left(stmt, start - 1);
stmt := from_cmd;
nested := 0;
quoted := FALSE;
idx := idx - start;
start := 1;
RAISE NOTICE E'SELECT SPLITTING1\n%\n%\n%', stmt, proj_cmd, from_cmd;
ELSIF pattern = 'DISTRIBUTE' THEN
RAISE NOTICE E'SELECT SPLITTING2\n%\n%\n%', stmt, proj_cmd, from_cmd;
CONTINUE;
ELSIF pattern = 'DISTRIBUTEBY' THEN
dist_cmd := substr(stmt, start + 1);
from_cmd := left(stmt, start - 1);
RAISE NOTICE E'SELECT SPLITTING3\n%\n%\n%\n%', stmt, proj_cmd, from_cmd, dist_cmd;
EXIT;
END IF;
pattern := '';
start := idx;
END LOOP;
END;
ELSE
RAISE EXCEPTION 'multiple SELECT clauses in i_commands: ''%'' ''%''', proj_cmd, command_str;
END IF;
ELSIF left(pattern, 5) = 'FROM ' THEN
IF from_cmd IS NULL THEN
from_cmd := command_str;
ELSE
RAISE EXCEPTION 'multiple FROM clauses in i_commands: ''%'' ''%''', from_cmd, command_str;
END IF;
ELSIF left(pattern, 14) = 'DISTRIBUTE BY ' THEN
IF dist_cmd IS NULL THEN
dist_cmd := command_str;
ELSE
RAISE EXCEPTION 'multiple DISTRIBUTE BY clauses in i_commands: ''%'' ''%''', dist_cmd, command_str;
END IF;
ELSE
RAISE EXCEPTION 'unrecognized command in i_commands: ''%''', command_str;
END IF;
END LOOP;
IF proj_cmd IS NULL THEN
-- minimum required input
IF from_cmd IS NULL THEN
RAISE EXCEPTION 'SELECT and FROM clauses are missing in i_commands';
END IF;
-- supply default projection
proj_cmd := 'SELECT *';
ELSE
IF from_cmd IS NULL AND strpos(upper(proj_cmd), 'FROM ') = 0 THEN
RAISE EXCEPTION 'FROM clause is missing in i_commands';
END IF;
END IF;
IF dist_cmd IS NULL THEN
dist_cmd := '';
END IF;
IF i_vers IS NULL OR i_vers = '' THEN
i_vers := s_vers_del || '1' || s_vers_sep || '0' || s_vers_sep || '0';
ELSE
i_vers := replace(i_vers, chr(2), s_vers_sep);
IF LEFT(i_vers, 1) <> s_vers_del THEN
i_vers := s_vers_del || i_vers;
ELSIF char_length(i_vers ) < 2 THEN
RAISE EXCEPTION 'illegal i_vers: ''%''', i_vers;
END IF;
IF strpos(substr(i_vers, 2), s_vers_del) > 0 THEN
RAISE EXCEPTION 'i_vers may contain only one single, leading ''%'' character', s_vers_del
USING HINT = 'specify snapshot version as [' || s_vers_del || ']x' || s_vers_sep || 'y' || s_vers_sep || 'z or ['
|| s_vers_del || ']label with optional, leading ''' || s_vers_del || '''';
END IF;
END IF;
IF char_length(i_name || i_vers) > 63 THEN
RAISE EXCEPTION 'snapshot name too long: ''%''', i_name || i_vers;
ELSE
i_name := i_name || i_vers;
END IF;
-- the final name of the snapshot
qual_name := quote_ident(i_schema) || '.' || quote_ident(i_name);
-- check for duplicate snapshot
IF 0 < (SELECT COUNT(*) FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name) THEN
RAISE EXCEPTION 'snapshot % already exists', qual_name;
END IF;
--SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
SELECT COALESCE(MAX(id)+1,0) FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb
-- execute using current user privileges
DECLARE
e_message TEXT; -- exception message
BEGIN
EXECUTE 'CREATE TEMPORARY TABLE _db4ai_tmp_x' || s_id || ' AS ' || proj_cmd
|| CASE WHEN from_cmd IS NULL THEN '' ELSE ' ' || from_cmd END;
EXCEPTION WHEN undefined_table THEN
GET STACKED DIAGNOSTICS e_message = MESSAGE_TEXT;
-- during function invocation, search path is redirected to {pg_temp, pg_catalog, function_schema} and becomes immutable
RAISE INFO 'could not resolve relation % using system-defined "search_path" setting during function invocation: ''%''',
substr(e_message, 10, 1 + strpos(substr(e_message,11), '" does not exist')),
array_to_string(current_schemas(TRUE),', ')
USING HINT = 'snapshots require schema-qualified table references, e.g. schema_name.table_name';
RAISE;
END;
-- extract normalized projection list
i_commands := ARRAY[proj_cmd, from_cmd, dist_cmd, '', ''];
SELECT string_agg(ident, ', '),
string_agg(ident || ' AS f' || ordinal_position, ', '),
string_agg('t' || s_id || '.f' || ordinal_position || ' AS ' || ident, ', ')
FROM ( SELECT ordinal_position, quote_ident(column_name) AS ident
FROM information_schema.columns
WHERE table_schema = (SELECT nspname FROM pg_namespace WHERE oid=pg_my_temp_schema())
AND table_name = '_db4ai_tmp_x' || s_id
ORDER BY ordinal_position
) INTO STRICT proj_cmd, i_commands[4], i_commands[5];
IF proj_cmd IS NULL THEN
RAISE EXCEPTION 'create snapshot internal error1: %', s_id;
END IF;
-- finalize the snapshot using elevated privileges
PERFORM db4ai.create_snapshot_internal(s_id, i_schema, i_name, i_commands, i_comment, CURRENT_USER);
-- drop temporary view used for privilege transfer
EXECUTE 'DROP TABLE _db4ai_tmp_x' || s_id;
-- create custom view, owned by current user
EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT ' || proj_cmd || ' FROM db4ai.v' || s_id;
EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
|| CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;
-- return final snapshot name
res := ROW(i_schema, i_name);
-- PG BUG: PG 9.2 cannot return composite type, only a reference to a variable of composite type
return res;
END;
$$;
COMMENT ON FUNCTION db4ai.create_snapshot() IS 'Create a new snapshot';
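For readers tracing the splitter embedded in `db4ai.create_snapshot` above, the following is a minimal Python sketch (not part of the patch; the function name and return shape are illustrative) of the same quote- and parenthesis-aware scan that cuts a SELECT command only at the top-level `FROM` and `DISTRIBUTE BY` keywords:

```python
def split_select(stmt):
    """Split a SELECT command into (projection, from_clause, distribute_clause).

    Keywords inside double-quoted identifiers (with "" as the escape) or inside
    parentheses are treated as plain text, mirroring the PL/pgSQL splitter.
    """
    nested = 0        # parenthesis nesting depth
    quoted = False    # inside a double-quoted identifier
    tokens = []       # (UPPER_TOKEN, start_index) pairs at top level
    token = ''
    token_start = 0
    i = 0
    while i < len(stmt):
        ch = stmt[i]
        if ch == '"':
            if quoted and i + 1 < len(stmt) and stmt[i + 1] == '"':
                i += 1                      # "" escape inside the identifier
            else:
                quoted = not quoted
            token += ch
        elif quoted:
            token += ch                     # keywords inside "..." are data
        elif ch == '(':
            nested += 1
            token += ch
        elif ch == ')':
            nested -= 1
            if nested < 0:
                raise ValueError("syntax error near ')' at position %d" % i)
            token += ch
        elif ch == ';':
            raise ValueError("syntax error near ';' at position %d" % i)
        elif ch.isspace() and nested == 0:
            if token:
                tokens.append((token.upper(), token_start))
                token = ''
            token_start = i + 1
        else:
            token += ch
        i += 1
    if token:
        tokens.append((token.upper(), token_start))

    # locate the first top-level FROM and DISTRIBUTE BY
    from_pos = dist_pos = None
    for n, (tok, pos) in enumerate(tokens):
        if tok == 'FROM' and from_pos is None:
            from_pos = pos
        elif (tok == 'DISTRIBUTE' and dist_pos is None
              and n + 1 < len(tokens) and tokens[n + 1][0] == 'BY'):
            dist_pos = pos
    end = dist_pos if dist_pos is not None else len(stmt)
    if from_pos is not None:
        proj = stmt[:from_pos].rstrip()
        from_part = stmt[from_pos:end].rstrip()
    else:
        proj = stmt[:end].rstrip()
        from_part = None
    dist = stmt[dist_pos:].rstrip() if dist_pos is not None else None
    return proj, from_part, dist
```

Quoted identifiers such as `"my ""col"` and nested subselects never trigger a split, which is why the original uses a character-level tokenizer rather than a simple regular expression.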

/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* deploy.sql
* Deploy DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/deploy.sql
*
* -------------------------------------------------------------------------
*/
-- db local objects
\set ON_ERROR_STOP 1
BEGIN;
\ir schema.sql
\ir create.sql
\ir prepare.sql
\ir sample.sql
\ir publish.sql
\ir purge.sql
COMMIT;
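The `create_snapshot` function in create.sql normalizes the optional version postfix before appending it to the snapshot name. A hedged Python sketch of that rule (default version `@1.0.0`, at most one leading delimiter, 63-character identifier cap matching NAMEDATALEN-1; the function name is illustrative):

```python
def snapshot_full_name(name, vers, vers_del='@', vers_sep='.'):
    """Combine a snapshot name and version postfix as create_snapshot does.

    '@' and '.' are the defaults for the db4ai_snapshot_version_delimiter
    and db4ai_snapshot_version_separator settings.
    """
    if not vers:
        # supply the default version postfix: @1.0.0
        vers = vers_del + '1' + vers_sep + '0' + vers_sep + '0'
    else:
        if not vers.startswith(vers_del):
            vers = vers_del + vers            # prepend missing delimiter
        elif len(vers) < 2:
            raise ValueError('illegal version: %r' % vers)
        if vers_del in vers[1:]:
            raise ValueError('version may contain only one leading %r' % vers_del)
    full = name + vers
    if len(full) > 63:                        # identifier length limit
        raise ValueError('snapshot name too long: %r' % full)
    return full
```

For example, `snapshot_full_name('sales', None)` yields the default-versioned name, while an embedded extra `@` is rejected.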

/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* prepare.sql
* Prepare DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/prepare.sql
*
* -------------------------------------------------------------------------
*/
CREATE OR REPLACE FUNCTION db4ai.prepare_snapshot_internal(
IN s_id BIGINT, -- snapshot id
IN p_id BIGINT, -- parent id
IN m_id BIGINT, -- matrix id
IN r_id BIGINT, -- root id
IN i_schema NAME, -- snapshot namespace
IN i_name NAME, -- snapshot name
IN i_commands TEXT[], -- DDL and DML commands defining snapshot modifications
IN i_comment TEXT, -- snapshot description
IN i_owner NAME, -- snapshot owner
INOUT i_idx INT, -- index for exec_cmds
INOUT i_exec_cmds TEXT[], -- DDL and DML for execution
IN i_mapping NAME[] DEFAULT NULL -- mapping of user columns to backing column; generate rules if not NULL
)
RETURNS RECORD LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
command_str TEXT; -- command string for iterator
e_stack_act TEXT; -- current stack for validation
row_count BIGINT; -- number of rows in this snapshot
BEGIN
BEGIN
RAISE EXCEPTION 'SECURITY_STACK_CHECK';
EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;
IF CURRENT_SCHEMA = 'db4ai' THEN
e_stack_act := replace(e_stack_act, ' prepare_snapshot(', ' db4ai.prepare_snapshot(');
e_stack_act := replace(e_stack_act, ' prepare_snapshot_internal(', ' db4ai.prepare_snapshot_internal(');
e_stack_act := replace(e_stack_act, ' sample_snapshot(', ' db4ai.sample_snapshot(');
END IF;
IF e_stack_act LIKE E'referenced column: i_idx\n'
'SQL statement "SELECT (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,\n'
' CURRENT_USER, idx, exec_cmds)).i_idx"\n%'
THEN
e_stack_act := substr(e_stack_act, 200);
END IF;
IF e_stack_act NOT SIMILAR TO 'PL/pgSQL function db4ai.prepare_snapshot\(name,name,text\[\],name,text\) line (175|541|607|714) at assignment%'
AND e_stack_act NOT LIKE 'PL/pgSQL function db4ai.sample_snapshot(name,name,name[],numeric[],name[],text[]) line 215 at IF%'
THEN
RAISE EXCEPTION 'direct call to db4ai.prepare_snapshot_internal(bigint,bigint,bigint,bigint,name,name,text[],text,name,'
'int,text[],name[]) is not allowed'
USING HINT = 'call public interface db4ai.prepare_snapshot instead';
END IF;
END;
-- generate rules from the mapping
IF i_mapping IS NOT NULL THEN
DECLARE
sel_view TEXT := 'CREATE OR REPLACE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ';
ins_grnt TEXT := 'GRANT INSERT (';
ins_rule TEXT := 'CREATE OR REPLACE RULE _INSERT AS ON INSERT TO db4ai.v' || s_id || ' DO INSTEAD INSERT INTO '
'db4ai.t' || coalesce(m_id, s_id) || '(';
ins_vals TEXT := ' VALUES (';
upd_grnt TEXT;
upd_rule TEXT;
dist_key NAME[] := array_agg(coalesce(m[1], replace(m[2], '""', '"'))) FROM regexp_matches(
getdistributekey('db4ai.t' || coalesce(m_id, p_id)),'([^\s",]+)|"((?:[^"]*"")*[^"]*)"', 'g') m;
BEGIN
FOR idx IN 3 .. array_length(i_mapping, 1) BY 3 LOOP
IF idx = 3 THEN
ins_grnt := ins_grnt || quote_ident(i_mapping[idx]);
ins_rule := ins_rule || coalesce(i_mapping[idx-2], i_mapping[idx-1]);
ins_vals := ins_vals || 'new.' || quote_ident(i_mapping[idx]);
ELSE
sel_view := sel_view || ', ';
ins_grnt := ins_grnt || ', ' || quote_ident(i_mapping[idx]);
ins_rule := ins_rule || ', ' || coalesce(i_mapping[idx-2], i_mapping[idx-1]);
ins_vals := ins_vals || ', ' || 'new.' || quote_ident(i_mapping[idx]);
END IF;
IF i_mapping[idx-2] IS NULL THEN -- handle shared columns without private (only CSS)
sel_view := sel_view || i_mapping[idx-1];
ELSE
IF i_mapping[idx-1] IS NULL THEN -- handle sole private column (all MSS and added CSS columns)
sel_view := sel_view || i_mapping[idx-2];
ELSE -- handle shadowing (CSS CASE)
sel_view := sel_view || 'coalesce(' || i_mapping[idx-2] || ', ' || i_mapping[idx-1] || ')';
END IF;
IF dist_key IS NULL OR NOT i_mapping[idx-2] = ANY(dist_key) THEN -- no updates on DISTRIBUTE BY columns
upd_grnt := CASE WHEN upd_grnt IS NULL -- grant update only on private column
THEN 'GRANT UPDATE (' ELSE upd_grnt ||', ' END || quote_ident(i_mapping[idx]);
upd_rule := CASE WHEN upd_rule IS NULL -- update only private column
THEN 'CREATE OR REPLACE RULE _UPDATE AS ON UPDATE TO db4ai.v' || s_id || ' DO INSTEAD UPDATE db4ai.t'
|| coalesce(m_id, s_id) || ' SET '
ELSE upd_rule || ', ' END
|| i_mapping[idx-2] || '=new.' || quote_ident(i_mapping[idx]); -- update private column
END IF;
END IF;
sel_view := sel_view || ' AS ' || quote_ident(i_mapping[idx]);
END LOOP;
i_exec_cmds := i_exec_cmds || ARRAY [
[ 'O', sel_view || ', xc_node_id, ctid FROM db4ai.t' || coalesce(m_id, s_id)
|| CASE WHEN m_id IS NULL THEN '' ELSE ' WHERE _' || s_id END ],
[ 'O', 'GRANT SELECT, DELETE ON db4ai.v' || s_id || ' TO ' || i_owner ],
[ 'O', ins_grnt || ') ON db4ai.v' || s_id || ' TO ' || i_owner ],
[ 'O', ins_rule || CASE WHEN m_id IS NULL THEN ')' ELSE ', _' || s_id || ')' END || ins_vals
|| CASE WHEN m_id IS NULL THEN ')' ELSE ', TRUE)' END ],
[ 'O', 'CREATE OR REPLACE RULE _DELETE AS ON DELETE TO db4ai.v' || s_id || ' DO INSTEAD '
|| CASE WHEN m_id IS NULL THEN 'DELETE FROM db4ai.t' || s_id ELSE 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || '=FALSE' END
|| ' WHERE t' || coalesce(m_id, s_id) || '.xc_node_id=old.xc_node_id AND t' || coalesce(m_id, s_id) || '.ctid=old.ctid' ] ];
IF upd_rule IS NOT NULL THEN
i_exec_cmds := i_exec_cmds || ARRAY [
[ 'O', upd_grnt || ') ON db4ai.v' || s_id || ' TO ' || i_owner ],
[ 'O', upd_rule || ' WHERE t' || coalesce(m_id, s_id) || '.xc_node_id=old.xc_node_id AND t' || coalesce(m_id, s_id) || '.ctid=old.ctid' ]];
END IF;
RETURN;
END;
END IF;
-- Execute the queries
LOOP EXIT WHEN i_idx = 1 + array_length(i_exec_cmds, 1);
CASE i_exec_cmds[i_idx][1]
WHEN 'O' THEN
-- RAISE NOTICE 'owner executing: %', i_exec_cmds[i_idx][2];
EXECUTE i_exec_cmds[i_idx][2];
i_idx := i_idx + 1;
WHEN 'U' THEN
RETURN;
ELSE -- this should never happen
RAISE EXCEPTION 'prepare snapshot internal error2: % %', i_idx, i_exec_cmds[i_idx];
END CASE;
END LOOP;
EXECUTE 'DROP RULE IF EXISTS _INSERT ON db4ai.v' || s_id;
EXECUTE 'DROP RULE IF EXISTS _UPDATE ON db4ai.v' || s_id;
EXECUTE 'DROP RULE IF EXISTS _DELETE ON db4ai.v' || s_id;
EXECUTE 'COMMENT ON VIEW db4ai.v' || s_id || ' IS ''snapshot ' || quote_ident(i_schema) || '.' || quote_ident(i_name)
|| ' backed by db4ai.t' || coalesce(m_id, s_id) || CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment
|| '"' ELSE '' END || '''';
EXECUTE 'REVOKE ALL PRIVILEGES ON db4ai.v' || s_id || ' FROM ' || i_owner;
EXECUTE 'GRANT SELECT ON db4ai.v' || s_id || ' TO ' || i_owner || ' WITH GRANT OPTION';
EXECUTE 'SELECT COUNT(*) FROM db4ai.v' || s_id INTO STRICT row_count;
INSERT INTO db4ai.snapshot (id, parent_id, matrix_id, root_id, schema, name, owner, commands, comment, row_count)
VALUES (s_id, p_id, m_id, r_id, i_schema, i_name, i_owner, i_commands, i_comment, row_count);
END;
$$;
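The column-mapping array that `prepare_snapshot_internal` consumes is a flat list of (private, shared, user_name) triplets; the snapshot view exposes the shared column when no private one exists, the private column when no shared one exists, and `coalesce(private, shared)` when a CSS column is shadowed. A small Python sketch of that projection-building step (illustrative only; `view_projection` and the minimal `quote_ident` are not part of the patch):

```python
def view_projection(mapping):
    """Build the snapshot view's SELECT list from column-mapping triplets.

    mapping is a flat list [private, shared, user_name, ...]: CSS shared
    columns have private=None, MSS and newly added columns have shared=None,
    shadowed CSS columns carry both and are read through coalesce().
    """
    def quote_ident(name):
        # minimal sketch of SQL identifier quoting
        return '"' + name.replace('"', '""') + '"'

    parts = []
    for i in range(0, len(mapping), 3):
        private, shared, user = mapping[i], mapping[i + 1], mapping[i + 2]
        if private is None:          # shared column only (CSS)
            expr = shared
        elif shared is None:         # sole private column (MSS, added CSS)
            expr = private
        else:                        # shadowed CSS column
            expr = 'coalesce(%s, %s)' % (private, shared)
        parts.append(expr + ' AS ' + quote_ident(user))
    return ', '.join(parts)
```

The same triplets also drive the generated INSERT/UPDATE/DELETE rewrite rules, which always target the private backing columns.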
CREATE OR REPLACE FUNCTION db4ai.prepare_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
IN i_parent NAME, -- parent snapshot name
IN i_commands TEXT[], -- DDL and DML commands defining snapshot modifications
IN i_vers NAME DEFAULT NULL, -- override version postfix
IN i_comment TEXT DEFAULT NULL -- description of this unit of data curation
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
s_id BIGINT; -- snapshot id
r_id BIGINT; -- root id
m_id BIGINT; -- matrix id
p_id BIGINT; -- parent id
c_id BIGINT; -- column id for backing table
s_name NAME; -- current snapshot name
s_mode VARCHAR(3); -- current snapshot mode
s_vers_del CHAR; -- snapshot version delimiter, default '@'
s_vers_sep CHAR; -- snapshot version separator, default '.'
s_uv_proj TEXT; -- snapshot user view projection list
p_name_vers TEXT[]; -- split full parent name into name and version
p_sv_proj TEXT; -- parent snapshot system view projection list
command_str TEXT; -- command string for iterator
pattern TEXT; -- command pattern for matching
ops_arr BOOLEAN[]; -- operation classes in i_commands
ops_str TEXT[] := '{ALTER, INSERT, DELETE, UPDATE}'; -- operation classes as string
ALTER_OP INT := 1; -- ALTER operation class
INSERT_OP INT := 2; -- INSERT operation class
DELETE_OP INT := 3; -- DELETE operation class
UPDATE_OP INT := 4; -- UPDATE operation class
vers_arr INT[]; -- split version digits
exec_cmds TEXT[]; -- commands for execution
qual_name TEXT; -- qualified snapshot name
ALTER_CLAUSE INT := 1; -- ALTER clause class
WHERE_CLAUSE INT := 2; -- WHERE clause class - for insert, update and delete
FROM_CLAUSE INT := 3; -- FROM clause class - for insert, update and delete (as USING)
SET_CLAUSE INT := 4; -- SET clause class - for updates and insert
-- (as generic SQL: projection, list, VALUES, ...)
AS_CLAUSE INT := 5; -- AS clause class - for delete and update
-- (correlation name) - default is "snapshot"
current_op INT; -- currently parsed operation class
next_op INT; -- following operation class
current_clauses TEXT[]; -- clauses for current operation
next_clauses TEXT[]; -- clauses for next operation
mapping NAME[]; -- mapping user column names to backing column names
newmap BOOLEAN := FALSE; -- mapping has changed
res db4ai.snapshot_name; -- composite result
BEGIN
-- obtain active message level
BEGIN
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
RAISE INFO 'effective client_min_messages is %', upper(current_setting('db4ai.message_level'));
EXCEPTION WHEN OTHERS THEN
NULL; -- db4ai.message_level not set, keep session default
END;
-- obtain active snapshot mode
BEGIN
s_mode := upper(current_setting('db4ai_snapshot_mode'));
EXCEPTION WHEN OTHERS THEN
s_mode := 'MSS';
END;
IF s_mode NOT IN ('CSS', 'MSS') THEN
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
END IF;
-- obtain relevant configuration parameters
BEGIN
s_vers_del := upper(current_setting('db4ai_snapshot_version_delimiter'));
EXCEPTION WHEN OTHERS THEN
s_vers_del := '@';
END;
BEGIN
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
EXCEPTION WHEN OTHERS THEN
s_vers_sep := '.';
END;
-- check all input parameters
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
IF i_parent IS NULL OR i_parent = '' THEN
RAISE EXCEPTION 'i_parent cannot be NULL or empty';
ELSE
i_parent := replace(i_parent, chr(1), s_vers_del);
i_parent := replace(i_parent, chr(2), s_vers_sep);
p_name_vers := regexp_split_to_array(i_parent, s_vers_del);
IF array_length(p_name_vers, 1) <> 2 OR array_length(p_name_vers, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_parent must contain exactly one ''%'' character', s_vers_del
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
END IF;
END IF;
-- check if parent exists
BEGIN
SELECT id, matrix_id, root_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_parent INTO STRICT p_id, m_id, r_id;
EXCEPTION WHEN NO_DATA_FOUND THEN
RAISE EXCEPTION 'parent snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_parent);
END;
--SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
SELECT MAX(id)+1 FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb
-- extract highest used c_id from the existing backing table or parent
-- cannot use information_schema here, because the current user has no read permission on the backing table
SELECT 1 + max(ltrim(attname, 'f')::BIGINT) FROM pg_catalog.pg_attribute INTO STRICT c_id
WHERE attrelid = ('db4ai.t' || coalesce(m_id, p_id))::regclass AND attnum > 0 AND NOT attisdropped AND attname like 'f%';
IF c_id IS NULL THEN
RAISE EXCEPTION 'prepare snapshot internal error3: %', coalesce(m_id, p_id);
END IF;
IF i_commands IS NULL OR array_length(i_commands, 1) IS NULL OR array_length(i_commands, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_commands array malformed'
USING HINT = 'pass SQL DML and DDL operations as TEXT[] literal, e.g. ''{ALTER, ADD a int, DROP c, DELETE, '
'WHERE b=5, INSERT, FROM t, UPDATE, FROM t, SET x=y, SET z=f(z), WHERE t.u=v}''';
END IF;
-- extract normalized projection list
p_sv_proj := substring(pg_get_viewdef('db4ai.v' || p_id), '^SELECT (.*), t[0-9]+\.xc_node_id, t[0-9]+\.ctid FROM.*$');
mapping := array(SELECT unnest(ARRAY[ m[1], m[2], coalesce(m[3], replace(m[4],'""','"'))])
FROM regexp_matches(p_sv_proj, CASE s_mode WHEN 'CSS'
-- inherited CSS columns are shared (private: nullable, shared: not null, user_cname: not null)
THEN '(?:COALESCE\(t[0-9]+\.(f[0-9]+), )?t[0-9]+\.(f[0-9]+)(?:\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")'
-- all MSS columns are private (private: not null, shared: nullable, user_cname: not null)
ELSE '(?:COALESCE\()?t[0-9]+\.(f[0-9]+)(?:, t[0-9]+\.(f[0-9]+)\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")'
END, 'g') m);
-- In principle two MSS naming conventions are possible:
-- (a) plain column names for MSS, allowing direct operations, but only with CompactSQL, not with TrueSQL. Conversion to CSS
--     then needs to rename user columns (if they are in f[0-9]+) or simply add columns at max fXX + 1
-- (b) translated column names for MSS and CSS. Simple MSS->CSS conversion. No direct operations, always using rewrite.
--     This is the more general approach!
-- The need for rewriting using rules:
--     UPDATE SET AS T SET T.x=y, "T.x"=y, "_$%_\\'"=NULL, T.z=DEFAULT, (a, b, T.c) = (SELECT 1 a, 2 b, 3 c)
--     FROM H AS I, J as K
--     WHERE _x_=5 AND _T.z_=5 AND v='T.z' AND (SELECT x, T.z FROM A as T)
-- unqualified, ambiguous, quoted, string literals, ... in the SET clause may still be manageable, but in WHERE
-- no guarantee of correctness is possible -> need to use the system's SQL parser with rewrite rules!
-- create / upgrade + prepare target snapshots for SQL DML/DDL operations
IF s_mode = 'MSS' THEN
DECLARE
s_bt_proj TEXT; -- snapshot backing table projection list
s_bt_dist TEXT; -- DISTRIBUTE BY clause for creating backing table
BEGIN
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
s_bt_proj := s_bt_proj || quote_ident(mapping[idx]) || ' AS ' || mapping[idx-2] || ',';
END LOOP;
s_bt_dist := getdistributekey('db4ai.t' || coalesce(m_id, p_id));
s_bt_dist := CASE WHEN s_bt_dist IS NULL
THEN ' DISTRIBUTE BY REPLICATION'
ELSE ' DISTRIBUTE BY HASH(' || s_bt_dist || ')' END;
s_bt_dist := ''; -- we silently drop DISTRIBUTE BY
exec_cmds := ARRAY [
[ 'O', 'CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)'
-- extract and propagate DISTRIBUTE BY from parent
|| s_bt_dist || ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id ]];
END;
ELSIF s_mode = 'CSS' THEN
IF m_id IS NULL THEN
exec_cmds := ARRAY [
[ 'O', 'UPDATE db4ai.snapshot SET matrix_id = ' || p_id || ' WHERE schema = ''' || i_schema || ''' AND name = '''
|| i_parent || '''' ],
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ADD _' || p_id || ' BOOLEAN NOT NULL DEFAULT TRUE' ],
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ALTER _' || p_id || ' SET DEFAULT FALSE' ],
[ 'O', 'CREATE OR REPLACE VIEW db4ai.v' || p_id || ' WITH(security_barrier) AS SELECT ' || p_sv_proj || ', xc_node_id, ctid FROM db4ai.t'
|| p_id || ' WHERE _' || p_id ]];
m_id := p_id;
END IF;
exec_cmds := exec_cmds || ARRAY [
[ 'O', 'ALTER TABLE db4ai.t' || m_id || ' ADD _' || s_id || ' BOOLEAN NOT NULL DEFAULT FALSE' ],
[ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id ]];
END IF;
-- generate and append grant, create view and rewrite rules for new snapshot
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
FOREACH command_str IN ARRAY i_commands LOOP
IF command_str IS NULL THEN
RAISE EXCEPTION 'i_commands array contains NULL values';
END IF;
END LOOP;
-- apply SQL DML/DDL according to snapshot mode
FOREACH command_str IN ARRAY (i_commands || ARRAY[NULL] ) LOOP
command_str := btrim(command_str);
pattern := upper(regexp_replace(left(command_str, 30), '\s+', ' ', 'g'));
IF pattern is NULL THEN
next_op := NULL;
ELSIF pattern = 'ALTER' THEN -- ALTER keyword is optional
next_op := ALTER_OP;
ELSIF pattern = 'DELETE' THEN
next_op := DELETE_OP;
ELSIF pattern = 'INSERT' THEN
next_op := INSERT_OP;
ELSIF pattern = 'UPDATE' THEN
next_op := UPDATE_OP;
ELSIF left(pattern, 7) = 'DELETE ' THEN
next_op := DELETE_OP;
SELECT coalesce(m[1], m[2]), m[3] FROM regexp_matches(command_str,
'^\s*DELETE\s+FROM\s*(?: snapshot |"snapshot")\s*(?:AS\s*(?: ([^\s"]+) |"((?:[^"]*"")*[^"]*)")\s*)?(.*)\s*$', 'i') m
INTO next_clauses[AS_CLAUSE], next_clauses[FROM_CLAUSE];
RAISE NOTICE E'XXX DELETE \n%\n%', command_str, array_to_string(next_clauses, E'\n');
ELSIF left(pattern, 7) = 'INSERT ' THEN
next_op := INSERT_OP;
SELECT m[1] FROM regexp_matches(command_str,
'^\s*INSERT\s+INTO\s*(?: snapshot |"snapshot")\s*(.*)\s*$', 'i') m
INTO STRICT next_clauses[SET_CLAUSE];
RAISE NOTICE E'XXX INSERT \n%\n%', command_str, array_to_string(next_clauses, E'\n');
ELSIF left(pattern, 7) = 'UPDATE ' THEN
next_op := UPDATE_OP;
SELECT coalesce(m[1], m[2]), m[3] FROM regexp_matches(command_str,
'^\s*UPDATE\s*(?: snapshot |"snapshot")\s*(?:AS\s*(?: ([^\s"]+) |"((?:[^"]*"")*[^"]*)")\s*)?(.*)\s*$', 'i') m
INTO STRICT next_clauses[AS_CLAUSE], next_clauses[SET_CLAUSE];
DECLARE
nested INT := 0; -- level of nesting
quoted BOOLEAN := FALSE; -- inside quoted identifier
cur_ch VARCHAR; -- current character in tokenizer
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
start INTEGER := 1;
stmt TEXT := next_clauses[SET_CLAUSE];
BEGIN
-- BEGIN splitter code for testing
pattern := '';
LOOP
idx := idx + 1;
cur_ch := substr(stmt, idx, 1);
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
CASE cur_ch
WHEN '"' THEN
IF quoted AND substr(stmt, idx + 1, 1) = '"' THEN
idx := idx + 1;
ELSE
quoted := NOT quoted;
END IF;
IF quoted THEN
CONTINUE;
END IF;
WHEN '(' THEN
nested := nested + 1;
CONTINUE;
WHEN ')' THEN
nested := nested - 1;
IF nested < 0 THEN
RAISE EXCEPTION 'syntax error at or near '')'' in ''%'' at position ''%''', stmt, idx;
END IF;
CONTINUE;
WHEN ' ' THEN
IF quoted OR nested > 0 THEN
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
start := idx;
CONTINUE;
END IF;
ELSE
pattern := pattern || upper(cur_ch);
CONTINUE;
END CASE;
-- END splitter code for testing
IF pattern IN ('FROM', 'WHERE') THEN
next_clauses[FROM_CLAUSE] := substr(next_clauses[SET_CLAUSE], start + 1);
next_clauses[SET_CLAUSE] := left(next_clauses[SET_CLAUSE], start - 1);
EXIT;
END IF;
pattern := '';
start := idx;
END LOOP;
END;
RAISE NOTICE E'XXX UPDATE \n%\n%', command_str, array_to_string(next_clauses, E'\n');
ELSIF left(pattern, 6) = 'ALTER ' THEN
SELECT m[1] FROM regexp_matches(command_str,
'^\s*ALTER\s+TABLE\s*(?: snapshot |"snapshot")\s*(.*)\s*$', 'i') m
INTO STRICT next_clauses[ALTER_CLAUSE];
RAISE NOTICE E'XXX ALTER \n%\n%', command_str, array_to_string(next_clauses, E'\n');
IF current_op IS NULL OR current_clauses[ALTER_CLAUSE] IS NULL THEN
next_op := ALTER_OP;
ELSE
current_clauses[ALTER_CLAUSE] := current_clauses[ALTER_CLAUSE] || ', ' || next_clauses[ALTER_CLAUSE];
next_clauses[ALTER_CLAUSE] := NULL;
END IF;
ELSIF left(pattern, 4) = 'ADD ' OR left(pattern, 5) = 'DROP ' THEN
-- for chaining, conflicting ALTER ops must be avoided by the user
IF current_op IS NULL OR current_op <> ALTER_OP THEN
next_op := ALTER_OP; -- ALTER keyword is optional
next_clauses[ALTER_CLAUSE] := command_str;
ELSIF current_clauses[ALTER_CLAUSE] IS NULL THEN
current_clauses[ALTER_CLAUSE] := command_str;
CONTINUE; -- allow chaining of ALTER ops
ELSE
current_clauses[ALTER_CLAUSE] := current_clauses[ALTER_CLAUSE] || ', ' || command_str;
CONTINUE; -- allow chaining of ALTER ops
END IF;
ELSIF left(pattern, 6) = 'WHERE ' THEN
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing INSERT / UPDATE / DELETE keyword before WHERE clause in i_commands at: ''%''', command_str;
ELSIF current_op NOT IN (INSERT_OP, UPDATE_OP, DELETE_OP) THEN
RAISE EXCEPTION 'illegal WHERE clause in % at: ''%''', ops_str[current_op], command_str;
ELSIF current_clauses[WHERE_CLAUSE] IS NULL THEN
current_clauses[WHERE_CLAUSE] := command_str;
ELSE
RAISE EXCEPTION 'multiple WHERE clauses in % at: ''%''', ops_str[current_op], command_str;
END IF;
CONTINUE;
ELSIF left(pattern, 5) = 'FROM ' THEN
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing INSERT / UPDATE keyword before FROM clause in i_commands at: ''%''', command_str;
ELSIF current_op NOT IN (INSERT_OP, UPDATE_OP) THEN
RAISE EXCEPTION 'illegal FROM clause in % at: ''%''', ops_str[current_op], command_str;
ELSIF current_clauses[FROM_CLAUSE] IS NULL THEN
current_clauses[FROM_CLAUSE] := command_str;
ELSE
RAISE EXCEPTION 'multiple FROM clauses in % at: ''%''', ops_str[current_op], command_str;
END IF;
CONTINUE;
ELSIF left(pattern, 6) = 'USING ' THEN
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing DELETE keyword before USING clause in i_commands at: ''%''', command_str;
ELSIF current_op NOT IN (DELETE_OP) THEN
RAISE EXCEPTION 'illegal USING clause in % at: ''%''', ops_str[current_op], command_str;
ELSIF current_clauses[FROM_CLAUSE] IS NULL THEN
current_clauses[FROM_CLAUSE] := command_str;
ELSE
RAISE EXCEPTION 'multiple USING clauses in DELETE at: ''%''', command_str;
END IF;
CONTINUE;
ELSIF left(pattern, 4) = 'SET ' THEN
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing UPDATE keyword before SET clause in i_commands at: ''%''', command_str;
ELSIF current_op NOT IN (UPDATE_OP) THEN
RAISE EXCEPTION 'illegal SET clause in % at: ''%''', ops_str[current_op], command_str;
-- for chaining, conflicting assignments must be avoided by the user
ELSE -- allow chaining of SET
current_clauses[SET_CLAUSE] := CASE WHEN current_clauses[SET_CLAUSE] IS NULL
THEN command_str ELSE current_clauses[SET_CLAUSE] || ' ' || command_str END;
END IF;
CONTINUE;
ELSIF left(pattern, 3) = 'AS ' THEN
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing UPDATE / DELETE keyword before AS clause in i_commands at: ''%''', command_str;
ELSIF current_op NOT IN (UPDATE_OP, DELETE_OP) THEN
RAISE EXCEPTION 'illegal AS clause in % at: ''%''', ops_str[current_op], command_str;
ELSIF current_clauses[AS_CLAUSE] IS NULL THEN
DECLARE
as_pos INT := 3 + strpos(upper(command_str), 'AS ');
BEGIN
current_clauses[AS_CLAUSE] := ltrim(substr(command_str, as_pos));
END;
ELSE
RAISE EXCEPTION 'multiple AS clauses in % at: ''%''', ops_str[current_op], command_str;
END IF;
CONTINUE;
ELSE -- generic SQL allowed only in INSERT
IF current_op IS NULL THEN
RAISE EXCEPTION 'missing ALTER / INSERT / UPDATE / DELETE keyword before SQL clause: ''%''', command_str;
ELSIF current_op NOT IN (INSERT_OP) THEN
RAISE EXCEPTION 'illegal SQL clause in % at: ''%''', ops_str[current_op], command_str;
ELSE -- allow chaining of generic SQL for PROJ LIST, VALUES, ...
current_clauses[SET_CLAUSE] := CASE WHEN current_clauses[SET_CLAUSE] IS NULL
THEN command_str ELSE current_clauses[SET_CLAUSE] || ' ' || command_str END;
END IF;
CONTINUE;
END IF;
IF current_op IS NOT NULL THEN
IF current_clauses IS NULL AND current_op NOT IN (DELETE_OP) THEN
RAISE EXCEPTION 'missing auxiliary clauses in %',
CASE WHEN command_str IS NULL THEN ops_str[current_op] ELSE ops_str[current_op] || ' before: '''
|| command_str || '''' END;
END IF;
IF current_clauses[AS_CLAUSE] IS NULL THEN
current_clauses[AS_CLAUSE] := 'snapshot';
END IF;
ops_arr[current_op] := TRUE;
IF current_op = ALTER_OP THEN
command_str := NULL; -- stores DDL statement for adding / dropping columns
DECLARE
dropif TEXT := NULL; -- drop if exists
expect BOOLEAN := TRUE; -- expect keyword
alt_op VARCHAR(4); -- alter operation: NULL or 'ADD' or 'DROP'
quoted BOOLEAN := FALSE; -- inside quoted identifier
cur_ch VARCHAR; -- current character in tokenizer
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
tokens TEXT := current_clauses[ALTER_CLAUSE] || ',';
BEGIN
-- BEGIN tokenizer code for testing
pattern := '';
LOOP
idx := idx + 1;
cur_ch := substr(tokens, idx, 1);
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
CASE cur_ch
WHEN '"' THEN
IF quoted AND substr(tokens, idx + 1, 1) = '"' THEN
pattern := pattern || '"';
idx := idx + 1;
ELSE
quoted := NOT quoted;
END IF;
IF quoted THEN
CONTINUE;
END IF;
WHEN ',' THEN
IF quoted THEN
pattern := pattern || cur_ch;
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
pattern := ',';
ELSE
idx := idx - 1; -- reset on comma for next loop
END IF;
WHEN ' ', E'\n', E'\t' THEN
IF quoted THEN
pattern := pattern || cur_ch;
CONTINUE;
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
CONTINUE;
END IF;
ELSE
pattern := pattern || CASE WHEN quoted THEN cur_ch ELSE lower(cur_ch) END;
CONTINUE;
END CASE;
-- END tokenizer code for testing
IF alt_op = 'DROP' AND upper(dropif) = 'IF' THEN
IF pattern = ',' THEN
pattern := dropif; -- interpret 'if' as column name (not a keyword)
idx := idx - 1; -- reset on comma for next loop
ELSIF upper(pattern) <> 'EXISTS' THEN
RAISE EXCEPTION 'expected EXISTS keyword in % operation after ''%'' in: ''%''',
alt_op, dropif, current_clauses[ALTER_CLAUSE];
END IF;
END IF;
IF expect THEN
IF upper(pattern) IN ('ADD', 'DROP') THEN
IF alt_op IS NULL THEN
alt_op := upper(pattern);
expect := FALSE;
ELSE
RAISE EXCEPTION 'unable to extract column name in % operation: ''%''',
alt_op, current_clauses[ALTER_CLAUSE];
END IF;
ELSE
RAISE EXCEPTION 'expected ADD or DROP keyword before ''%'' in: ''%''',
pattern, current_clauses[ALTER_CLAUSE]
USING HINT = 'currently only ADD and DROP supported';
END IF;
ELSIF pattern = ',' THEN
expect := TRUE; -- allow chaining of ALTER ops
ELSIF alt_op IS NULL THEN
-- accept all possibly legal text following column name
-- leave exact syntax check to SQL compiler
IF command_str IS NOT NULL THEN
command_str := command_str || ' ' || pattern;
END IF;
ELSIF upper(pattern) = 'COLUMN' THEN
-- skip keyword COLUMN between ADD/DROP and column name
ELSIF alt_op = 'DROP' AND upper(pattern) = 'IF' AND dropif IS NULL THEN
dropif := pattern; -- buffer 'IF': not a reserved word here, it may yet turn out to be a column name
ELSIF alt_op = 'DROP' AND upper(pattern) = 'EXISTS' AND upper(dropif) = 'IF' THEN
dropif := pattern; -- 'IF EXISTS' complete; suppresses the unknown-column error below
ELSIF alt_op IN ('ADD', 'DROP') THEN
-- attempt to map the pattern
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
IF pattern = mapping[idx] THEN
IF alt_op = 'ADD' THEN
-- check if pattern was mapped to an existing column
RAISE EXCEPTION 'column "%" already exists in current snapshot', pattern;
ELSIF alt_op = 'DROP' THEN
-- DROP a private column (MSS and CSS)
IF mapping[idx-2] IS NOT NULL THEN
command_str := CASE WHEN command_str IS NULL
THEN 'ALTER TABLE db4ai.t' || coalesce(m_id, s_id)
ELSE command_str || ',' END || ' DROP ' || mapping[idx-2];
END IF;
mapping := mapping[1:(idx-3)] || mapping[idx+1:(array_length(mapping, 1))];
newmap := TRUE;
alt_op := NULL;
EXIT;
END IF;
END IF;
END LOOP;
-- apply the mapping
IF alt_op = 'ADD' THEN
-- ADD a private column (MSS and CSS)
command_str := CASE WHEN command_str IS NULL
THEN 'ALTER TABLE db4ai.t' || coalesce(m_id, s_id)
ELSE command_str || ',' END || ' ADD f' || c_id;
mapping := mapping || ARRAY [ 'f' || c_id, NULL, pattern ]::NAME[];
newmap := TRUE;
c_id := c_id + 1;
ELSIF alt_op = 'DROP' THEN
-- check whether pattern needs mapping to an existing column
IF dropif IS NULL OR upper(dropif) <> 'EXISTS' THEN
RAISE EXCEPTION 'unable to map field "%" to backing table in % operation: ''%''',
pattern, alt_op, current_clauses[ALTER_CLAUSE];
END IF;
END IF;
dropif := NULL;
alt_op := NULL;
ELSE
-- checked before, this should never happen
RAISE EXCEPTION 'unexpected ALTER clause: %', alt_op;
END IF;
pattern := '';
END LOOP;
IF quoted THEN
RAISE EXCEPTION 'unterminated quoted identifier ''"%'' at or near: ''%''',
substr(pattern, 1, char_length(pattern)-1), current_clauses[ALTER_CLAUSE];
END IF;
-- CREATE OR REPLACE: cannot drop columns from view - MUST use DROP / CREATE
-- clear view dependencies for backing table columns
exec_cmds := exec_cmds || ARRAY [ 'O', 'DROP VIEW IF EXISTS db4ai.v' || s_id ];
-- append the DDL statement for the backing table (if any)
IF command_str IS NOT NULL THEN
exec_cmds := exec_cmds || ARRAY [ 'O', command_str ];
END IF;
IF newmap THEN
-- generate and append grant, create view and rewrite rules for new snapshot
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
newmap := FALSE;
END IF;
END;
ELSIF current_op = INSERT_OP THEN
IF current_clauses[SET_CLAUSE] IS NULL THEN
RAISE EXCEPTION 'missing SELECT or VALUES clause in INSERT operation';
END IF;
exec_cmds := exec_cmds || ARRAY [
'U', 'INSERT INTO db4ai.v' || s_id
|| ' ' || current_clauses[SET_CLAUSE] -- generic SQL
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
ELSIF current_op = DELETE_OP THEN
exec_cmds := exec_cmds || ARRAY [
'U', 'DELETE FROM db4ai.v' || s_id || ' AS ' || current_clauses[AS_CLAUSE]
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END -- USING
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
ELSIF current_op = UPDATE_OP THEN
command_str := NULL; -- stores DDL statement for adding shadow columns
IF current_clauses[SET_CLAUSE] IS NULL THEN
RAISE EXCEPTION 'missing SET clause in UPDATE operation';
END IF;
-- extract updated fields and check their mapping
FOR pattern IN
SELECT coalesce(m[1], replace(m[2],'""','"'))
FROM regexp_matches(current_clauses[SET_CLAUSE],
'([^\s"]+)\s*=|"((?:[^"]*"")*[^"]*)"\s*=','g') m
LOOP
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
IF pattern = mapping[idx] THEN
-- ADD a private column (only CSS)
IF mapping[idx-2] IS NULL THEN
command_str := CASE WHEN command_str IS NULL
THEN 'ALTER TABLE db4ai.t' || m_id
ELSE command_str || ',' END
|| ' ADD f' || c_id || ' '
|| format_type(atttypid, atttypmod) FROM pg_catalog.pg_attribute
WHERE attrelid = ('db4ai.t' || m_id)::regclass AND attname = mapping[idx-1];
mapping[idx-2] := 'f' || c_id;
newmap := TRUE;
c_id := c_id + 1;
END IF;
pattern := NULL;
EXIT;
END IF;
END LOOP;
-- check if pattern was mapped
IF pattern IS NOT NULL THEN
RAISE EXCEPTION 'unable to map field "%" to backing table in UPDATE operation: %',
pattern, current_clauses[SET_CLAUSE];
END IF;
END LOOP;
-- append the DDL statement for the backing table for adding shadow columns (if any)
IF command_str IS NOT NULL THEN
exec_cmds := exec_cmds || ARRAY [ 'O', command_str ];
END IF;
IF newmap THEN
-- generate and append grant, create view and rewrite rules for new snapshot
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
newmap := FALSE;
END IF;
exec_cmds := exec_cmds || ARRAY [
'U', 'UPDATE db4ai.v' || s_id || ' AS ' || current_clauses[AS_CLAUSE]
|| ' ' || current_clauses[SET_CLAUSE]
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
END IF;
END IF;
current_op := next_op;
next_op := NULL;
-- restore ALTER clause for ADD / DROP without 'ALTER' keyword, else reset to NULL
current_clauses := next_clauses;
next_clauses := NULL;
END LOOP;
-- compute final version string
IF i_vers IS NULL OR i_vers = '' THEN
BEGIN
vers_arr := regexp_split_to_array(p_name_vers[2], CASE s_vers_sep WHEN '.' THEN '\.' ELSE s_vers_sep END);
IF array_length(vers_arr, 1) <> 3 OR array_length(vers_arr, 2) IS NOT NULL OR
vers_arr[1] ~ '[^0-9]' OR vers_arr[2] ~ '[^0-9]' OR vers_arr[3] ~ '[^0-9]' THEN
RAISE EXCEPTION 'illegal version format';
END IF;
IF ops_arr[ALTER_OP] THEN
vers_arr[1] := vers_arr[1] + 1;
vers_arr[2] := 0;
vers_arr[3] := 0;
ELSIF ops_arr[INSERT_OP] OR ops_arr[DELETE_OP] THEN
vers_arr[2] := vers_arr[2] + 1;
vers_arr[3] := 0;
ELSE
vers_arr[3] := vers_arr[3] + 1;
END IF;
i_vers := s_vers_del || array_to_string(vers_arr, s_vers_sep);
EXCEPTION WHEN OTHERS THEN
RAISE EXCEPTION 'parent snapshot has nonstandard version ''%''; i_vers cannot be NULL or empty', p_name_vers[2]
USING HINT = 'provide a custom version for the new snapshot using the i_vers parameter';
END;
ELSE
i_vers := replace(i_vers, chr(2), s_vers_sep);
IF LEFT(i_vers, 1) <> s_vers_del THEN
i_vers := s_vers_del || i_vers;
ELSIF char_length(i_vers) < 2 THEN
RAISE EXCEPTION 'illegal i_vers: ''%''', s_vers_del;
END IF;
IF strpos(substr(i_vers, 2), s_vers_del) > 0 THEN
RAISE EXCEPTION 'i_vers may contain only one single, leading ''%'' character', s_vers_del
USING HINT = 'specify snapshot version as [' || s_vers_del || ']x' || s_vers_sep || 'y' || s_vers_sep || 'z or ['
|| s_vers_del || ']label with optional, leading ''' || s_vers_del || '''';
END IF;
END IF;
IF char_length(p_name_vers[1] || i_vers) > 63 THEN
RAISE EXCEPTION 'snapshot name too long: ''%''', p_name_vers[1] || i_vers;
ELSE
s_name := p_name_vers[1] || i_vers;
END IF;
-- the final name of the snapshot
qual_name := quote_ident(i_schema) || '.' || quote_ident(s_name);
-- check for duplicate snapshot
IF 0 < (SELECT COUNT(0) FROM db4ai.snapshot WHERE schema = i_schema AND name = s_name) THEN
RAISE EXCEPTION 'snapshot % already exists' , qual_name;
END IF;
IF s_mode = 'MSS' THEN
exec_cmds := exec_cmds || ARRAY [
'O', 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || qual_name || '''' ];
END IF;
-- Execute the queries
RAISE NOTICE E'accumulated commands:\n%', array_to_string(exec_cmds, E'\n');
DECLARE
idx INTEGER := 1; -- loop counter, cannot use FOR .. iterator
BEGIN
LOOP EXIT WHEN idx = 1 + array_length(exec_cmds, 1);
WHILE exec_cmds[idx][1] = 'U' LOOP
-- RAISE NOTICE 'user executing: %', exec_cmds[idx][2];
DECLARE
e_message TEXT; -- exception message
BEGIN
EXECUTE exec_cmds[idx][2];
idx := idx + 1;
EXCEPTION WHEN undefined_table THEN
GET STACKED DIAGNOSTICS e_message = MESSAGE_TEXT;
-- during function invocation, search path is redirected to {pg_temp, pg_catalog, function_schema} and becomes immutable
RAISE INFO 'could not resolve relation % using system-defined "search_path" setting during function invocation: ''%''',
substr(e_message, 10, 1 + strpos(substr(e_message,11), '" does not exist')),
array_to_string(current_schemas(TRUE),', ')
USING HINT = 'snapshots require schema-qualified table references, e.g. schema_name.table_name';
RAISE;
END;
END LOOP;
IF idx < array_length(exec_cmds, 1) AND (exec_cmds[idx][1] IS NULL OR exec_cmds[idx][1] <> 'O') THEN -- this should never happen
RAISE EXCEPTION 'prepare snapshot internal error1: % %', idx, exec_cmds[idx];
END IF;
-- execute owner statements (if any) and epilogue
idx := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
CURRENT_USER, idx, exec_cmds)).i_idx;
END LOOP;
END;
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
s_uv_proj := s_uv_proj || quote_ident(mapping[idx]) || ',';
END LOOP;
-- create custom view, owned by current user
EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT '|| rtrim(s_uv_proj, ',') || ' FROM db4ai.v' || s_id;
EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
|| CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;
-- return final snapshot name
res := ROW(i_schema, s_name);
return res;
END;
$$;
COMMENT ON FUNCTION db4ai.prepare_snapshot() IS 'Prepare snapshot from existing for data curation';
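Two pieces of `prepare_snapshot` above are easier to follow in isolation: the hand-rolled ALTER-clause tokenizer (double-quoted identifiers keep their case, `""` escapes an embedded quote, unquoted text is lowercased, and each `,` becomes a token of its own) and the semver-style version bump (ALTER increments the major number, INSERT/DELETE the minor, anything else the patch level). A minimal Python sketch of both; function names and signatures are illustrative only, not part of the db4ai API:

```python
def tokenize_alter_clause(clause: str) -> list:
    """Sketch of the ALTER-clause tokenizer in prepare_snapshot."""
    tokens, pattern, quoted = [], "", False
    text = clause + ","                # trailing comma flushes the last token
    i = 0
    while i < len(text):
        ch = text[i]
        i += 1
        emit = False
        if ch == '"':
            if quoted and i < len(text) and text[i] == '"':
                pattern += '"'         # '""' inside a quoted identifier
                i += 1
            else:
                quoted = not quoted
                emit = not quoted      # a closing quote completes the token
        elif ch == ",":
            if quoted:
                pattern += ch
            elif pattern == "":
                pattern, emit = ",", True
            else:
                i -= 1                 # re-read the comma after emitting
                emit = True
        elif ch in " \n\t":
            if quoted:
                pattern += ch
            elif pattern:
                emit = True
        else:
            pattern += ch if quoted else ch.lower()
        if emit:
            tokens.append(pattern)
            pattern = ""
    return tokens

def bump_version(parent_vers: str, did_alter: bool, did_insert_or_delete: bool,
                 sep: str = ".") -> str:
    """Sketch of the version bump rule: ALTER -> major+1 (minor/patch reset),
    INSERT/DELETE -> minor+1 (patch reset), anything else -> patch+1."""
    parts = parent_vers.split(sep)
    if len(parts) != 3 or not all(p.isdigit() for p in parts):
        raise ValueError("illegal version format")
    major, minor, patch = (int(p) for p in parts)
    if did_alter:
        major, minor, patch = major + 1, 0, 0
    elif did_insert_or_delete:
        minor, patch = minor + 1, 0
    else:
        patch += 1
    return sep.join(str(n) for n in (major, minor, patch))
```

Note the trailing `,` token: like the SQL, the sketch appends a sentinel comma so the final token is flushed and ALTER operations can be chained.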

@@ -0,0 +1,154 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* publish.sql
* Publish DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/publish.sql
*
* -------------------------------------------------------------------------
*/
CREATE OR REPLACE FUNCTION db4ai.manage_snapshot_internal(
IN i_schema NAME, -- snapshot namespace
IN i_name NAME, -- snapshot name
IN publish BOOLEAN -- publish or archive
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp SET client_min_messages TO ERROR
AS $$
DECLARE
s_mode VARCHAR(3); -- current snapshot mode
s_vers_del CHAR; -- snapshot version delimiter, default '@'
s_vers_sep CHAR; -- snapshot version separator, default '.'
s_name_vers TEXT[]; -- split snapshot id into name and version
e_stack_act TEXT; -- current stack for validation
res db4ai.snapshot_name; -- composite result
BEGIN
BEGIN
RAISE EXCEPTION 'SECURITY_STACK_CHECK';
EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;
IF CURRENT_SCHEMA = 'db4ai' THEN
e_stack_act := replace(e_stack_act, ' archive_snapshot(', ' db4ai.archive_snapshot(');
e_stack_act := replace(e_stack_act, ' publish_snapshot(', ' db4ai.publish_snapshot(');
END IF;
IF e_stack_act NOT SIMILAR TO '%PL/pgSQL function db4ai.(archive|publish)_snapshot\(name,name\) line 11 at assignment%'
THEN
RAISE EXCEPTION 'direct call to db4ai.manage_snapshot_internal(name,name,boolean) is not allowed'
USING HINT = 'call public interface db4ai.(publish|archive)_snapshot instead';
END IF;
END;
-- obtain active message level
BEGIN
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
EXCEPTION WHEN OTHERS THEN
END;
-- obtain relevant configuration parameters
BEGIN
s_mode := upper(current_setting('db4ai_snapshot_mode'));
EXCEPTION WHEN OTHERS THEN
s_mode := 'MSS';
END;
IF s_mode NOT IN ('CSS', 'MSS') THEN
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
END IF;
-- obtain relevant configuration parameters
BEGIN
s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
EXCEPTION WHEN OTHERS THEN
s_vers_del := '@';
END;
BEGIN
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
EXCEPTION WHEN OTHERS THEN
s_vers_sep := '.';
END;
-- check all input parameters
IF i_name IS NULL OR i_name = '' THEN
RAISE EXCEPTION 'i_name cannot be NULL or empty';
ELSE
i_name := replace(i_name, chr(1), s_vers_del);
i_name := replace(i_name, chr(2), s_vers_sep);
s_name_vers := regexp_split_to_array(i_name, s_vers_del);
IF array_length(s_name_vers, 1) <> 2 OR array_length(s_name_vers, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_name must contain exactly one ''%'' character', s_vers_del
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
END IF;
END IF;
UPDATE db4ai.snapshot SET published = publish, archived = NOT publish WHERE schema = i_schema AND name = i_name;
IF SQL%ROWCOUNT = 0 THEN
RAISE EXCEPTION 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
END IF;
res := ROW(i_schema, i_name);
return res;
END;
$$;
CREATE OR REPLACE FUNCTION db4ai.archive_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER
IN i_name NAME -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
res db4ai.snapshot_name; -- composite result
BEGIN
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
-- return archived snapshot name
res := db4ai.manage_snapshot_internal(i_schema, i_name, FALSE);
return res;
END;
$$;
COMMENT ON FUNCTION db4ai.archive_snapshot() IS 'Archive snapshot for preventing usage in model training';
CREATE OR REPLACE FUNCTION db4ai.publish_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
IN i_name NAME -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
res db4ai.snapshot_name; -- composite result
BEGIN
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
-- return published snapshot name
res := db4ai.manage_snapshot_internal(i_schema, i_name, TRUE);
return res;
END;
$$;
COMMENT ON FUNCTION db4ai.publish_snapshot() IS 'Publish snapshot for allowing usage in model training';
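The `i_name` handling above (shared with `purge_snapshot`) first substitutes the control characters `chr(1)`/`chr(2)` with the configured version delimiter and separator, then requires the delimiter to split the id into exactly a name and a version. A small Python sketch of that check; the function name is illustrative, not part of the SQL interface:

```python
def split_snapshot_id(i_name: str, vers_del: str = "@", vers_sep: str = ".") -> tuple:
    """Sketch of the i_name validation in publish/archive/purge: \x01 and \x02
    are placeholders for the configured delimiter and separator, and the
    delimiter must occur exactly once."""
    if not i_name:
        raise ValueError("i_name cannot be NULL or empty")
    # clients may pass chr(1)/chr(2) where the delimiter/separator are awkward to quote
    i_name = i_name.replace("\x01", vers_del).replace("\x02", vers_sep)
    parts = i_name.split(vers_del)
    if len(parts) != 2:
        raise ValueError("i_name must contain exactly one %r character" % vers_del)
    return parts[0], parts[1]
```

This mirrors the `regexp_split_to_array` length check: zero or more than one delimiter is rejected with a hint to use the `snapshot_name@version` format.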

@@ -0,0 +1,199 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* purge.sql
* Purge DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/purge.sql
*
* -------------------------------------------------------------------------
*/
CREATE OR REPLACE FUNCTION db4ai.purge_snapshot_internal(
IN i_schema NAME, -- snapshot namespace
IN i_name NAME -- snapshot name
)
RETURNS VOID LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
s_id BIGINT; -- snapshot id
p_id BIGINT; -- parent id
m_id BIGINT; -- matrix id
o_id BIGINT[]; -- other snapshot ids in same backing table
pushed_cmds TEXT[]; -- commands to be pushed to descendants
pushed_comment TEXT; -- comments to be pushed to descendants
drop_cols NAME[]; -- orphaned columns
e_stack_act TEXT; -- current stack for validation
affected BIGINT; -- number of affected rows
BEGIN
BEGIN
RAISE EXCEPTION 'SECURITY_STACK_CHECK';
EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;
IF CURRENT_SCHEMA = 'db4ai' THEN
e_stack_act := replace(e_stack_act, 'ion pur', 'ion db4ai.pur');
END IF;
IF e_stack_act NOT LIKE 'referenced column: purge_snapshot_internal
SQL statement "SELECT db4ai.purge_snapshot_internal(i_schema, i_name)"
PL/pgSQL function db4ai.purge_snapshot(name,name) line 62 at PERFORM%'
THEN
RAISE EXCEPTION 'direct call to db4ai.purge_snapshot_internal(name,name) is not allowed'
USING HINT = 'call public interface db4ai.purge_snapshot instead';
END IF;
END;
-- check if snapshot exists
BEGIN
SELECT commands, comment, id, parent_id, matrix_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name
INTO STRICT pushed_cmds, pushed_comment, s_id, p_id, m_id;
EXCEPTION WHEN NO_DATA_FOUND THEN
RAISE EXCEPTION 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
END;
-- update descendants, if any
UPDATE db4ai.snapshot SET
parent_id = p_id,
commands = pushed_cmds || commands,
comment = CASE WHEN pushed_comment IS NULL THEN comment
WHEN comment IS NULL THEN pushed_comment
ELSE pushed_comment || ' | ' || comment END
WHERE parent_id = s_id;
IF p_id IS NULL AND SQL%ROWCOUNT > 0 THEN
RAISE EXCEPTION 'cannot purge root snapshot ''%.%'' having dependent snapshots', quote_ident(i_schema), quote_ident(i_name)
USING HINT = 'purge all dependent snapshots first';
END IF;
IF m_id IS NULL THEN
EXECUTE 'DROP VIEW db4ai.v' || s_id;
EXECUTE 'DROP TABLE db4ai.t' || s_id;
RAISE NOTICE 'PURGE_SNAPSHOT: MSS backing table dropped';
ELSE
SELECT array_agg(id) FROM db4ai.snapshot WHERE matrix_id = m_id AND id <> s_id INTO STRICT o_id;
IF o_id IS NULL OR array_length(o_id, 1) = 0 THEN
EXECUTE 'DROP VIEW db4ai.v' || s_id;
EXECUTE 'DROP TABLE db4ai.t' || m_id;
RAISE NOTICE 'PURGE_SNAPSHOT: CSS backing table dropped';
ELSE
EXECUTE 'DELETE FROM db4ai.t' || m_id || ' WHERE _' || s_id || ' AND NOT (_' || array_to_string(o_id, ' OR _') || ')';
GET DIAGNOSTICS affected = ROW_COUNT;
SELECT array_agg(quote_ident(column_name))
FROM ( SELECT column_name
FROM information_schema.columns
WHERE table_schema = 'db4ai' AND table_name = ANY ( ('{v' || array_to_string(s_id || o_id, ',v') || '}')::NAME[] )
GROUP BY column_name
HAVING SUM(CASE table_name WHEN 'v' || s_id THEN 0 ELSE 1 END) = 0 )
INTO STRICT drop_cols;
EXECUTE 'DROP VIEW db4ai.v' || s_id;
IF TRUE OR drop_cols IS NULL THEN -- column reclamation disabled for now: always take this branch and keep orphaned columns
EXECUTE 'ALTER TABLE db4ai.t' || m_id || ' DROP _' || s_id;
RAISE NOTICE 'PURGE_SNAPSHOT: orphaned rows dropped: %, orphaned columns dropped: none', affected;
ELSE
EXECUTE 'ALTER TABLE db4ai.t' || m_id || ' DROP _' || s_id || ', DROP ' || array_to_string(drop_cols, ', DROP ');
RAISE NOTICE 'PURGE_SNAPSHOT: orphaned rows dropped: %, orphaned columns dropped: %', affected, drop_cols;
END IF;
END IF;
END IF;
DELETE FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name;
IF SQL%ROWCOUNT = 0 THEN
-- checked before, this should never happen
RAISE INFO 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION db4ai.purge_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
IN i_name NAME -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
s_mode VARCHAR(3); -- current snapshot mode
s_vers_del CHAR; -- snapshot version delimiter, default '@'
s_vers_sep CHAR; -- snapshot version separator, default '.'
s_name_vers TEXT[]; -- split full name into name and version
res db4ai.snapshot_name; -- composite result
BEGIN
-- obtain active message level
BEGIN
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
EXCEPTION WHEN OTHERS THEN
END;
-- obtain active snapshot mode
BEGIN
s_mode := upper(current_setting('db4ai_snapshot_mode'));
EXCEPTION WHEN OTHERS THEN
s_mode := 'MSS';
END;
IF s_mode NOT IN ('CSS', 'MSS') THEN
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
END IF;
-- obtain relevant configuration parameters
BEGIN
s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
EXCEPTION WHEN OTHERS THEN
s_vers_del := '@';
END;
BEGIN
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
EXCEPTION WHEN OTHERS THEN
s_vers_sep := '.';
END;
-- check all input parameters
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
IF i_name IS NULL OR i_name = '' THEN
RAISE EXCEPTION 'i_name cannot be NULL or empty';
ELSE
i_name := replace(i_name, chr(1), s_vers_del);
i_name := replace(i_name, chr(2), s_vers_sep);
s_name_vers := regexp_split_to_array(i_name, s_vers_del);
IF array_length(s_name_vers, 1) <> 2 OR array_length(s_name_vers, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_name must contain exactly one ''%'' character', s_vers_del
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
END IF;
END IF;
BEGIN
EXECUTE 'DROP VIEW ' || quote_ident(i_schema) || '.' || quote_ident(i_name);
EXCEPTION WHEN OTHERS THEN
END;
PERFORM db4ai.purge_snapshot_internal(i_schema, i_name);
-- return purged snapshot name
res := ROW(i_schema, i_name);
return res;
END;
$$;
COMMENT ON FUNCTION db4ai.purge_snapshot() IS 'Purge a snapshot and reclaim occupied storage';
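In CSS mode, `purge_snapshot_internal` above computes orphaned columns as those that appear in the purged snapshot's view but in none of the surviving views over the same backing table (the `HAVING SUM(...) = 0` aggregate); note the actual `DROP` of those columns is currently disabled in the code. The underlying set logic, sketched in Python with an illustrative function name:

```python
def orphaned_columns(purged_view_cols, surviving_view_cols):
    """Sketch of the HAVING SUM(...) = 0 aggregate in purge_snapshot_internal:
    a backing-table column is orphaned when the purged snapshot's view is the
    only view that still projects it."""
    used = set()
    for cols in surviving_view_cols:       # columns any surviving view still needs
        used |= set(cols)
    return sorted(set(purged_view_cols) - used)
```

For example, purging a view over `f1, f2, f3` while surviving views still use `f1` and `f3` leaves only `f2` reclaimable.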

@@ -0,0 +1,269 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* sample.sql
* Sample DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/sample.sql
*
* -------------------------------------------------------------------------
*/
CREATE OR REPLACE FUNCTION db4ai.sample_snapshot(
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
IN i_parent NAME, -- parent snapshot name
IN i_sample_infixes NAME[], -- sample snapshot name infixes
IN i_sample_ratios NUMBER[], -- size of each sample, as a ratio of the parent set
IN i_stratify NAME[] DEFAULT NULL, -- stratification fields
IN i_sample_comments TEXT[] DEFAULT NULL -- sample snapshot descriptions
)
RETURNS SETOF db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
s_id BIGINT; -- snapshot id
p_id BIGINT; -- parent id
m_id BIGINT; -- matrix id
r_id BIGINT; -- root id
s_mode VARCHAR(3); -- current snapshot mode
s_vers_del CHAR; -- snapshot version delimiter, default '@'
s_vers_sep CHAR; -- snapshot version separator, default '.'
s_sv_proj TEXT; -- snapshot system view projection list
s_bt_proj TEXT; -- snapshot backing table projection list
s_bt_dist TEXT; -- DISTRIBUTE BY clause for creating backing table
s_uv_proj TEXT; -- snapshot user view projection list
p_sv_proj TEXT; -- parent snapshot system view projection list
p_name_vers TEXT[]; -- split full parent name into name and version
stratify_count BIGINT[]; -- count per stratification class
exec_cmds TEXT[]; -- commands for execution
qual_name TEXT; -- qualified snapshot name
mapping NAME[]; -- mapping user column names to backing column names
s_name db4ai.snapshot_name; -- snapshot sample name
BEGIN
-- obtain active message level
BEGIN
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
EXCEPTION WHEN OTHERS THEN
END;
-- obtain active snapshot mode
BEGIN
s_mode := upper(current_setting('db4ai_snapshot_mode'));
EXCEPTION WHEN OTHERS THEN
s_mode := 'MSS';
END;
IF s_mode NOT IN ('CSS', 'MSS') THEN
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
END IF;
-- obtain relevant configuration parameters
BEGIN
s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
EXCEPTION WHEN OTHERS THEN
s_vers_del := '@';
END;
BEGIN
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
EXCEPTION WHEN OTHERS THEN
s_vers_sep := '.';
END;
-- check all input parameters
IF i_schema IS NULL OR i_schema = '' THEN
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
END IF;
IF i_parent IS NULL OR i_parent = '' THEN
RAISE EXCEPTION 'i_parent cannot be NULL or empty';
ELSE
i_parent := replace(i_parent, chr(1), s_vers_del);
i_parent := replace(i_parent, chr(2), s_vers_sep);
p_name_vers := regexp_split_to_array(i_parent, s_vers_del);
IF array_length(p_name_vers, 1) <> 2 OR array_length(p_name_vers, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_parent must contain exactly one ''%'' character', s_vers_del
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
END IF;
END IF;
-- check if parent exists
BEGIN
SELECT id, matrix_id, root_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_parent INTO STRICT p_id, m_id, r_id;
EXCEPTION WHEN NO_DATA_FOUND THEN
RAISE EXCEPTION 'parent snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_parent);
END;
IF i_sample_infixes IS NULL OR array_length(i_sample_infixes, 1) IS NULL OR array_length(i_sample_infixes, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_sample_infixes array malformed'
USING HINT = 'pass sample infixes as NAME[] literal, e.g. ''{_train, _test}''';
END IF;
IF i_sample_ratios IS NULL OR array_length(i_sample_ratios, 1) IS NULL OR array_length(i_sample_ratios, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_sample_ratios array malformed'
USING HINT = 'pass sample percentages as NUMBER[] literal, e.g. ''{.8, .2}''';
END IF;
IF array_length(i_sample_infixes, 1) <> array_length(i_sample_ratios, 1) THEN
RAISE EXCEPTION 'i_sample_infixes and i_sample_ratios array length mismatch';
END IF;
IF i_stratify IS NOT NULL THEN
IF array_length(i_stratify, 1) IS NULL OR array_length(i_stratify, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_stratify array malformed'
USING HINT = 'pass stratification field names as NAME[] literal, e.g. ''{color, size}''';
END IF;
EXECUTE 'SELECT ARRAY[COUNT(DISTINCT ' || array_to_string(i_stratify, '), COUNT(DISTINCT ') || ')] FROM db4ai.v' || p_id
INTO STRICT stratify_count;
IF stratify_count IS NULL THEN
RAISE EXCEPTION 'sample snapshot internal error2: %', p_id;
END IF;
SELECT array_agg(ordered) FROM (SELECT unnest(i_stratify) ordered ORDER BY unnest(stratify_count)) INTO STRICT i_stratify;
IF i_stratify IS NULL THEN
RAISE EXCEPTION 'sample snapshot internal error3';
END IF;
END IF;
IF i_sample_comments IS NOT NULL THEN
IF array_length(i_sample_comments, 1) IS NULL OR array_length(i_sample_comments, 2) IS NOT NULL THEN
RAISE EXCEPTION 'i_sample_comments array malformed'
USING HINT = 'pass sample comments as TEXT[] literal, e.g. ''{comment 1, comment 2}''';
ELSIF array_length(i_sample_infixes, 1) <> array_length(i_sample_comments, 1) THEN
RAISE EXCEPTION 'i_sample_infixes and i_sample_comments array length mismatch';
END IF;
END IF;
-- extract normalized projection list (private: nullable, shared: not null, user_cname: not null)
p_sv_proj := substring(pg_get_viewdef('db4ai.v' || p_id), '^SELECT (.*), t[0-9]+\.xc_node_id, t[0-9]+\.ctid FROM.*$');
mapping := array(SELECT unnest(ARRAY[ m[1], m[2], coalesce(m[3], replace(m[4],'""','"'))]) FROM regexp_matches(p_sv_proj,
'(?:COALESCE\(t[0-9]+\.(f[0-9]+), )?t[0-9]+\.(f[0-9]+)(?:\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")', 'g') m);
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
IF s_mode = 'MSS' THEN
s_sv_proj := s_sv_proj || coalesce(mapping[idx-2], mapping[idx-1]) || ' AS ' || quote_ident(mapping[idx]) || ',';
s_bt_proj := s_bt_proj || quote_ident(mapping[idx]) || ' AS ' || coalesce(mapping[idx-2], mapping[idx-1]) || ',';
ELSIF s_mode = 'CSS' THEN
IF mapping[idx-2] IS NULL THEN
s_sv_proj := s_sv_proj || mapping[idx-1] || ' AS ' || quote_ident(mapping[idx]) || ',';
ELSE
s_sv_proj := s_sv_proj || 'coalesce(' || mapping[idx-2] || ',' || mapping[idx-1] || ') AS ' || quote_ident(mapping[idx]) || ',';
END IF;
END IF;
s_uv_proj := s_uv_proj || quote_ident(mapping[idx]) || ',';
END LOOP;
s_bt_dist := getdistributekey('db4ai.t' || coalesce(m_id, p_id));
s_bt_dist := CASE WHEN s_bt_dist IS NULL
THEN ' DISTRIBUTE BY REPLICATION'
ELSE ' DISTRIBUTE BY HASH(' || s_bt_dist || ')' END;
s_bt_dist := ''; -- DISTRIBUTE BY suppressed for now; remove this line to propagate the parent's distribution
FOR i IN 1 .. array_length(i_sample_infixes, 1) LOOP
IF i_sample_infixes[i] IS NULL THEN
RAISE EXCEPTION 'i_sample_infixes array contains NULL values';
END IF;
IF i_sample_ratios[i] IS NULL THEN
RAISE EXCEPTION 'i_sample_ratios array contains NULL values';
END IF;
qual_name := p_name_vers[1] || i_sample_infixes[i] || s_vers_del || p_name_vers[2];
IF char_length(qual_name) > 63 THEN
RAISE EXCEPTION 'sample snapshot name too long: ''%''', qual_name;
ELSE
s_name := (i_schema, qual_name);
qual_name := quote_ident(s_name.schema) || '.' || quote_ident(s_name.name);
END IF;
IF i_sample_ratios[i] < 0 OR i_sample_ratios[i] > 1 THEN
RAISE EXCEPTION 'sample ratio must be between 0 and 1';
END IF;
-- SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
SELECT MAX(id)+1 FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb
-- check for duplicate snapshot
IF 0 < (SELECT COUNT(*) FROM db4ai.snapshot WHERE schema = s_name.schema AND name = s_name.name) THEN
RAISE EXCEPTION 'snapshot % already exists' , qual_name;
END IF;
-- SET seed TO 0.444;
-- setseed(0.444);
-- dbms_random.seed(0.888);
-- create / upgrade + prepare target snapshots for SQL DML/DDL operations
IF s_mode = 'MSS' THEN
exec_cmds := ARRAY [
-- extract and propagate DISTRIBUTE BY from root MSS snapshot
[ 'O','CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)' || s_bt_dist
|| ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id || ' WHERE random() <= ' || i_sample_ratios[i] ],
-- || ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id || ' WHERE dbms_random.value(0, 1) <= ' || i_sample_ratios[i],
[ 'O', 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || qual_name || '''' ],
[ 'O', 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || s_sv_proj || ' xc_node_id, ctid FROM db4ai.t' || s_id ]];
ELSIF s_mode = 'CSS' THEN
IF m_id IS NULL THEN
exec_cmds := ARRAY [
[ 'O', 'UPDATE db4ai.snapshot SET matrix_id = ' || p_id || ' WHERE schema = ''' || i_schema || ''' AND name = '''
|| i_parent || '''' ],
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ADD _' || p_id || ' BOOLEAN NOT NULL DEFAULT TRUE' ],
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ALTER _' || p_id || ' SET DEFAULT FALSE' ],
[ 'O', 'CREATE OR REPLACE VIEW db4ai.v' || p_id || ' WITH(security_barrier) AS SELECT ' || p_sv_proj || ', xc_node_id, ctid FROM db4ai.t'
|| p_id || ' WHERE _' || p_id ]];
m_id := p_id;
END IF;
exec_cmds := exec_cmds || ARRAY [
[ 'O', 'ALTER TABLE db4ai.t' || m_id || ' ADD _' || s_id || ' BOOLEAN NOT NULL DEFAULT FALSE' ],
[ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id || ' AND random() <= '
-- [ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id || ' AND dbms_random.value(0, 1) <= '
|| i_sample_ratios[i] ],
[ 'O', 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || s_sv_proj || ' xc_node_id, ctid FROM db4ai.t' || m_id
|| ' WHERE _' || s_id ]];
END IF;
-- || ' AS SELECT ' || proj_list || ' FROM '
-- || '(SELECT *, count(*) OVER() _cnt, row_number() OVER('
-- || CASE WHEN i_stratify IS NOT NULL THEN 'ORDER BY ' || array_to_string(i_stratify, ', ') END
-- || ') _row FROM db4ai.v' || p_id
-- || ') WHERE round(_row/100 = 0
--|| ' TABLESAMPLE SYSTEM ( ' || i_sample_ratios[i] || ') REPEATABLE (888)';
--SELECT * FROM (SELECT *, count(*) OVER() _cnt, row_number() OVER(ORDER BY COLOR) _row FROM t) WHERE _row % (_cnt / 10) = 0;
-- Execute the queries
RAISE NOTICE E'accumulated commands:\n%', array_to_string(exec_cmds, E'\n');
IF 1 + array_length(exec_cmds, 1) <> (db4ai.prepare_snapshot_internal(
s_id, p_id, m_id, r_id, s_name.schema, s_name.name,
ARRAY [ 'SAMPLE ' || i_sample_infixes[i] || ' ' || i_sample_ratios[i] ||
CASE WHEN i_stratify IS NULL THEN '' ELSE ' ' || i_stratify::TEXT END ],
i_sample_comments[i], CURRENT_USER, 1, exec_cmds)).i_idx THEN
RAISE EXCEPTION 'sample snapshot internal error1';
END IF;
-- create custom view, owned by current user
EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT ' || rtrim(s_uv_proj, ',') || ' FROM db4ai.v' || s_id;
EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
|| CASE WHEN length(i_sample_comments[i]) > 0 THEN ' comment is "' || i_sample_comments[i] || '"' ELSE '' END || '''';
EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;
exec_cmds := NULL;
RETURN NEXT s_name;
END LOOP;
END;
$$;
COMMENT ON FUNCTION db4ai.sample_snapshot() IS 'Create samples from a snapshot';
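Each sample snapshot above is drawn with `WHERE random() <= ratio`, i.e. every row is kept or dropped independently, so the sample size is only binomially close to `ratio * row_count`, not exact. A rough Python analogue of this Bernoulli row sampling (illustrative only; `rows` and `ratio` are hypothetical stand-ins for the backing table and an `i_sample_ratios` entry):

```python
import random

def bernoulli_sample(rows, ratio, seed=None):
    """Mimic 'WHERE random() <= ratio': keep each row independently."""
    if not 0 <= ratio <= 1:
        raise ValueError("sample ratio must be between 0 and 1")
    rng = random.Random(seed)
    return [row for row in rows if rng.random() <= ratio]

rows = list(range(10000))
train = bernoulli_sample(rows, 0.8, seed=42)  # roughly 8000 rows, not exactly
```

This is also why the commented-out `TABLESAMPLE ... REPEATABLE` and stratified `row_number()` variants appear in the function body: they trade this simple one-pass filter for exact or stratified sample sizes.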


@ -0,0 +1,74 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* schema.sql
* Schema for DB4AI.Snapshot functionality.
*
* IDENTIFICATION
* src/gausskernel/dbmind/db4ai/snapshots/schema.sql
*
* -------------------------------------------------------------------------
*/
GRANT USAGE ON SCHEMA db4ai TO PUBLIC;
--CREATE SEQUENCE db4ai.snapshot_sequence; -- openGauss BUG: cannot create sequences in initdb
--GRANT USAGE ON SEQUENCE db4ai.snapshot_sequence TO db4ai;
CREATE TYPE db4ai.snapshot_name AS ("schema" NAME, "name" NAME); -- openGauss BUG: array type not created during initdb
CREATE TABLE IF NOT EXISTS db4ai.snapshot
(
id BIGINT UNIQUE, -- snapshot id (surrogate key)
parent_id BIGINT, -- parent snapshot id (references snapshot.id)
matrix_id BIGINT, -- matrix id from CSS snapshots, else NULL
-- (references snapshot.id)
root_id BIGINT, -- id of the initial snapshot, constructed via
-- db4ai.create_snapshot() from operational data
-- (references snapshot.id)
schema NAME NOT NULL, -- schema where the snapshot view is exported
name NAME NOT NULL, -- name of the snapshot, including version postfix
owner NAME NOT NULL, -- name of the user who created this snapshot
commands TEXT[] NOT NULL, -- complete list of SQL statements documenting how
-- to generate this snapshot from its ancestor
comment TEXT, -- description of the snapshot
published BOOLEAN NOT NULL DEFAULT FALSE, -- TRUE, iff the snapshot is currently published
archived BOOLEAN NOT NULL DEFAULT FALSE, -- TRUE, iff the snapshot is currently archived
created TIMESTAMP DEFAULT CURRENT_TIMESTAMP,-- timestamp of snapshot creation date
row_count BIGINT NOT NULL, -- number of rows in this snapshot
PRIMARY KEY (schema, name)
) /* DISTRIBUTE BY REPLICATION */;
COMMENT ON TABLE db4ai.snapshot IS 'system catalog of meta-data on DB4AI snapshots';
COMMENT ON COLUMN db4ai.snapshot.id IS 'snapshot id (surrogate key)';
COMMENT ON COLUMN db4ai.snapshot.parent_id IS 'parent snapshot id (references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.matrix_id IS E'matrix id from CSS snapshots, else NULL\n'
'(references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.root_id IS E'id of the initial snapshot, constructed via\n'
'db4ai.create_snapshot() from operational data\n'
'(references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.schema IS 'schema where the snapshot view is exported';
COMMENT ON COLUMN db4ai.snapshot.name IS 'name of the snapshot, including version postfix';
COMMENT ON COLUMN db4ai.snapshot.owner IS 'name of the user who created this snapshot';
COMMENT ON COLUMN db4ai.snapshot.commands IS E'complete list of SQL statements documenting how\n'
'to generate this snapshot from its ancestor';
COMMENT ON COLUMN db4ai.snapshot.comment IS 'description of the snapshot';
COMMENT ON COLUMN db4ai.snapshot.published IS 'TRUE, iff the snapshot is currently published';
COMMENT ON COLUMN db4ai.snapshot.archived IS 'TRUE, iff the snapshot is currently archived';
COMMENT ON COLUMN db4ai.snapshot.created IS 'timestamp of snapshot creation date';
-- public read-only access to snapshot catalog
REVOKE ALL PRIVILEGES ON db4ai.snapshot FROM PUBLIC;
GRANT SELECT ON db4ai.snapshot TO PUBLIC;
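As the `matrix_id` column suggests, CSS snapshots share one backing table: each new sample adds a per-snapshot boolean column and flags its member rows, instead of copying rows into a fresh table as MSS does. A minimal Python analogue of that flag-based membership (hypothetical names, not the actual storage layout):

```python
def create_css_sample(matrix, parent_flag, sample_flag, keep):
    """Mark rows of a shared backing table as members of a new CSS snapshot.

    matrix: list of dicts, one per row; flags play the role of the
    per-snapshot boolean columns. keep(row) decides membership among
    rows visible in the parent snapshot only.
    """
    for row in matrix:
        # only rows already in the parent snapshot are candidates
        row[sample_flag] = row.get(parent_flag, False) and keep(row)

matrix = [{"id": i, "_p": i % 2 == 0} for i in range(10)]
create_css_sample(matrix, "_p", "_s1", keep=lambda r: r["id"] < 6)
view = [r["id"] for r in matrix if r["_s1"]]  # the snapshot "view"
```

The snapshot view then simply filters on its own flag column, which is what the `CREATE VIEW ... WHERE _<id>` statements in `sample_snapshot` generate.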


@ -23,6 +23,7 @@
#include "catalog/gs_obsscaninfo.h"
#include "catalog/pg_obsscaninfo.h"
#include "catalog/pg_type.h"
#include "db4ai/create_model.h"
#include "commands/createas.h"
#include "commands/defrem.h"
#include "commands/prepare.h"
@ -87,6 +88,7 @@ THR_LOCAL explain_get_index_name_hook_type explain_get_index_name_hook = NULL;
extern TrackDesc trackdesc[];
extern sortMessage sortmessage[];
extern DestReceiver* CreateDestReceiver(CommandDest dest);
/* Array to record plan table column names, type, etc */
static const PlanTableEntry plan_table_cols[] = {{"id", PLANID, INT4OID, ANAL_OPT},
@ -653,6 +655,22 @@ static void ExplainOneQuery(
return;
}
else if (IsA(query->utilityStmt, CreateModelStmt)) {
CreateModelStmt* cm = (CreateModelStmt*) query->utilityStmt;
/*
* Create the tuple receiver object and insert the hyperparameters it will need
*/
DestReceiverTrainModel* dest_train_model = NULL;
dest_train_model = (DestReceiverTrainModel*) CreateDestReceiver(DestTrainModel);
configure_dest_receiver_train_model(dest_train_model, (AlgorithmML) cm->algorithm, cm->model, queryString);
PlannedStmt* plan = plan_create_model(cm, queryString, params, (DestReceiver*)dest_train_model);
ExplainOnePlan(plan, into, es, queryString, dest, params);
return;
}
ExplainOneUtility(query->utilityStmt, into, es, queryString, params);
return;
@ -1928,6 +1946,7 @@ static void ExplainNode(
case T_WorkTableScan:
case T_ForeignScan:
case T_VecForeignScan:
case T_GradientDescentState:
ExplainScanTarget((Scan*)plan, es);
break;
case T_ExtensiblePlan:


@ -477,6 +477,12 @@ void GetPlanNodePlainText(
*pname = "Vector Merge";
*sname = *pt_operation = "Vector Merge Join";
break;
case T_GradientDescent:
*pname = *sname = *pt_options = "Gradient Descent";
break;
case T_KMeans:
*pname = *sname = *pt_options = "K-Means";
break;
default:
*pname = *sname = *pt_operation = "?\?\?";
break;


@ -33,6 +33,7 @@
#include "access/xact.h"
#include "commands/copy.h"
#include "commands/createas.h"
#include "db4ai/create_model.h"
#include "commands/matview.h"
#include "executor/functions.h"
#include "executor/spi.h"
@ -150,6 +151,10 @@ DestReceiver* CreateDestReceiver(CommandDest dest)
case DestBatchLocalRoundRobin:
case DestBatchHybrid:
return createStreamDestReceiver(dest);
case DestTrainModel:
return CreateTrainModelDestReceiver();
default:
break;
}
@ -188,6 +193,7 @@ void EndCommand(const char* commandTag, CommandDest dest)
case DestCopyOut:
case DestSQLFunction:
case DestTransientRel:
case DestTrainModel:
default:
break;
}
@ -218,6 +224,7 @@ void EndCommand_noblock(const char* commandTag, CommandDest dest)
case DestIntoRel:
case DestCopyOut:
case DestSQLFunction:
case DestTrainModel:
default:
break;
}
@ -265,6 +272,7 @@ void NullCommand(CommandDest dest)
case DestCopyOut:
case DestSQLFunction:
case DestTransientRel:
case DestTrainModel:
default:
break;
}
@ -321,6 +329,7 @@ void ReadyForQuery(CommandDest dest)
case DestCopyOut:
case DestSQLFunction:
case DestTransientRel:
case DestTrainModel:
default:
break;
}
@ -355,6 +364,7 @@ void ReadyForQuery_noblock(CommandDest dest, int timeout)
case DestIntoRel:
case DestCopyOut:
case DestSQLFunction:
case DestTrainModel:
default:
break;
}


@ -102,6 +102,7 @@
#include "gs_policy/gs_policy_audit.h"
#include "gs_policy/policy_common.h"
#include "client_logic/client_logic.h"
#include "db4ai/create_model.h"
#ifdef ENABLE_MULTIPLE_NODES
#include "pgxc/pgFdwRemote.h"
#include "pgxc/globalStatistic.h"
@ -6649,6 +6650,13 @@ void standard_ProcessUtility(Node* parse_tree, const char* query_string, ParamLi
(void)process_column_settings((CreateClientLogicColumn *)parse_tree);
}
break;
case T_CreateModelStmt:{ // DB4AI
exec_create_model((CreateModelStmt*) parse_tree, query_string, params, completion_tag);
break;
}
default: {
ereport(ERROR,
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE),
@ -7729,6 +7737,9 @@ const char* CreateCommandTag(Node* parse_tree)
case OBJECT_CONVERSION:
tag = "DROP CONVERSION";
break;
case OBJECT_DB4AI_MODEL:
tag = "DROP MODEL";
break;
case OBJECT_SCHEMA:
tag = "DROP SCHEMA";
break;
@ -8390,6 +8401,9 @@ const char* CreateCommandTag(Node* parse_tree)
case T_ShutdownStmt:
tag = "SHUTDOWN";
break;
case T_CreateModelStmt:
tag = "CREATE MODEL";
break;
default:
elog(WARNING, "unrecognized node type: %d", (int)nodeTag(parse_tree));
@ -9136,6 +9150,9 @@ LogStmtLevel GetCommandLogLevel(Node* parse_tree)
case T_ShutdownStmt:
lev = LOGSTMT_ALL;
break;
case T_CreateModelStmt: // DB4AI
lev = LOGSTMT_ALL;
break;
default:
elog(WARNING, "unrecognized node type: %d", (int)nodeTag(parse_tree));


@ -47,7 +47,7 @@ OBJS = execAmi.o execCurrent.o execGrouping.o execJunk.o execMain.o \
nodeGroup.o nodeSubplan.o nodeSubqueryscan.o nodeTidscan.o \
nodeForeignscan.o nodeWindowAgg.o tstoreReceiver.o spi.o \
nodePartIterator.o nodeStub.o execClusterResize.o lightProxy.o execMerge.o \
nodeExtensible.o opfusion.o opfusion_scan.o opfusion_util.o route.o
nodeExtensible.o opfusion.o opfusion_scan.o opfusion_util.o route.o nodeGD.o nodeKMeans.o
override CPPFLAGS += -D__STDC_FORMAT_MACROS


@ -162,6 +162,8 @@
#include "securec.h"
#include "gstrace/gstrace_infra.h"
#include "gstrace/executer_gstrace.h"
#include "executor/nodeGD.h"
#include "executor/nodeKMeans.h"
#define NODENAMELEN 64
@ -391,6 +393,10 @@ PlanState* ExecInitNodeByType(Plan* node, EState* estate, int eflags)
return (PlanState*)ExecInitVecMergeJoin((VecMergeJoin*)node, estate, eflags);
case T_VecWindowAgg:
return (PlanState*)ExecInitVecWindowAgg((VecWindowAgg*)node, estate, eflags);
case T_GradientDescent:
return (PlanState*)ExecInitGradientDescent((GradientDescent*)node, estate, eflags);
case T_KMeans:
return (PlanState*)ExecInitKMeans((KMeans*)node, estate, eflags);
default:
ereport(ERROR,
(errmodule(MOD_EXECUTOR),
@ -692,7 +698,10 @@ TupleTableSlot* ExecProcNodeByType(PlanState* node)
result = ExecStream((StreamState*)node);
t_thrd.pgxc_cxt.GlobalNetInstr = NULL;
return result;
case T_GradientDescentState:
return ExecGradientDescent((GradientDescentState*)node);
case T_KMeansState:
return ExecKMeans((KMeansState*)node);
default:
ereport(ERROR,
(errmodule(MOD_EXECUTOR),
@ -1343,6 +1352,14 @@ static void ExecEndNodeByType(PlanState* node)
ExecEndVecWindowAgg((VecWindowAggState*)node);
break;
case T_GradientDescentState:
ExecEndGradientDescent((GradientDescentState*)node);
break;
case T_KMeansState:
ExecEndKMeans((KMeansState*)node);
break;
default:
ereport(ERROR,
(errmodule(MOD_EXECUTOR),


@ -69,6 +69,7 @@
#include "gstrace/gstrace_infra.h"
#include "gstrace/executer_gstrace.h"
#include "commands/trigger.h"
#include "db4ai/gd.h"
/* static function decls */
static Datum ExecEvalArrayRef(ArrayRefExprState* astate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
@ -147,6 +148,8 @@ static Datum ExecEvalGroupingIdExpr(
GroupingIdExprState* gstate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
static bool func_has_refcursor_args(Oid Funcid, FunctionCallInfoData* fcinfo);
extern Datum ExecEvalGradientDescent(GradientDescentExprState* mlstate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
THR_LOCAL PLpgSQL_execstate* plpgsql_estate = NULL;
/* ----------------------------------------------------------------
@ -5412,6 +5415,20 @@ ExprState* ExecInitExpr(Expr* node, PlanState* parent)
state = (ExprState*)rnstate;
state->evalfunc = (ExprStateEvalFunc)ExecEvalRownum;
} break;
case T_GradientDescentExpr: {
GradientDescentExprState* ml_state = (GradientDescentExprState*)makeNode(GradientDescentExprState);
ml_state->ps = parent;
ml_state->xpr = (GradientDescentExpr*)node;
state = (ExprState*)ml_state;
if (IsA(parent, GradientDescentState)) {
state->evalfunc = (ExprStateEvalFunc)ExecEvalGradientDescent;
} else {
ereport(ERROR,
(errmodule(MOD_DB4AI),
errcode(ERRCODE_INVALID_OPERATION),
errmsg("unrecognized state %d for GradientDescentExpr", parent->type)));
}
} break;
default:
ereport(ERROR,
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE),


@ -0,0 +1,457 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* nodeGD.cpp
*
* IDENTIFICATION
* src/gausskernel/runtime/executor/nodeGD.cpp
*
* ---------------------------------------------------------------------------------------
*/
#include "postgres.h"
#include "executor/executor.h"
#include "executor/nodeGD.h"
#include "db4ai/gd.h"
//////////////////////////////////////////////////////////////////////////
GradientDescentHook_iteration gdhook_iteration = nullptr;
static bool transfer_slot(GradientDescentState* gd_state, TupleTableSlot* slot,
int ith_tuple, Matrix* features, Matrix* dep_var)
{
const GradientDescent* gd_node = gd_get_node(gd_state);
Assert(ith_tuple < (int)features->rows);
if (!slot->tts_isnull[gd_node->targetcol]) {
int feature = 0;
gd_float* w = features->data + ith_tuple * features->columns;
for (int i = 0; i < get_natts(gd_state); i++) {
if (i == gd_node->targetcol && !dep_var_is_continuous(gd_state->algorithm)) {
if (!dep_var_is_binary(gd_state->algorithm)) {
ereport(ERROR,
(errmodule(MOD_DB4AI),
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("categorical dependent variable not implemented")));
}
float dt = 0;
if (get_atttypid(gd_state, gd_node->targetcol) == BOOLOID)
dt = DatumGetBool(slot->tts_values[gd_node->targetcol])
? gd_state->algorithm->max_class : gd_state->algorithm->min_class;
else {
bool found = false;
for (int v = 0; v < gd_state->num_classes && !found; v++) {
found = datumIsEqual(slot->tts_values[gd_node->targetcol], gd_state->binary_classes[v],
get_attbyval(gd_state, gd_node->targetcol),
get_attlen(gd_state, gd_node->targetcol));
if (found)
dt = (v == 1 ? gd_state->algorithm->max_class : gd_state->algorithm->min_class);
}
if (!found) {
if (gd_state->num_classes == 2)
ereport(ERROR,
(errmodule(MOD_DB4AI),
errcode(ERRCODE_TOO_MANY_ARGUMENTS),
errmsg("too many target values for binary operator")));
gd_state->binary_classes[gd_state->num_classes++] =
datumCopy(slot->tts_values[gd_node->targetcol],
get_attbyval(gd_state, gd_node->targetcol),
get_attlen(gd_state, gd_node->targetcol));
}
}
dep_var->data[ith_tuple] = dt;
} else {
gd_float value;
if (slot->tts_isnull[i]) {
Assert(i != gd_node->targetcol);
value = 0.0; // default value for a NULL feature; the target cannot be NULL here (asserted above)
}
else
value = gd_datum_get_float(get_atttypid(gd_state, i), slot->tts_values[i]);
if (i == gd_node->targetcol)
dep_var->data[ith_tuple] = value;
else {
*w++ = value;
feature++;
}
}
}
Assert(feature == gd_state->n_features-1);
*w = 1.0; // bias
return true;
}
return false;
}
void exec_gd_batch(GradientDescentState* gd_state, int iter)
{
// get information from the node
const GradientDescent* gd_node = gd_get_node(gd_state);
PlanState* outer_plan = outerPlanState(gd_state);
Matrix* features;
Matrix* dep_var;
bool more = true;
TupleTableSlot* slot = NULL;
do {
// read next batch
features = gd_state->shuffle->get(gd_state->shuffle, &dep_var);
int ith_tuple = 0;
while (more && ith_tuple < gd_node->batch_size) {
slot = ExecProcNode(outer_plan);
if (TupIsNull(slot)) {
more = false;
} else {
if (transfer_slot(gd_state, slot, ith_tuple, features, dep_var)) {
if (iter == 0)
gd_state->processed++;
ith_tuple++;
} else {
if (iter == 0)
gd_state->discarded++;
}
}
}
// use the batch to test now in case the shuffle algorithm
// releases it during unget
if (iter > 0 && ith_tuple > 0) {
if (ith_tuple < gd_node->batch_size) {
matrix_resize(features, ith_tuple, gd_state->n_features);
matrix_resize(dep_var, ith_tuple, 1);
}
double loss = gd_state->algorithm->test_callback(gd_node, features, dep_var, &gd_state->weights, &gd_state->scores);
gd_state->loss += loss;
ereport(DEBUG1,
(errmodule(MOD_DB4AI),
errmsg("iteration %d loss = %.6f (total %.6g)", iter, loss, gd_state->loss)));
if (ith_tuple < gd_node->batch_size) {
matrix_resize(features, gd_node->batch_size, gd_state->n_features);
matrix_resize(dep_var, gd_node->batch_size, 1);
}
}
// give back the batch to the shuffle algorithm
gd_state->shuffle->unget(gd_state->shuffle, ith_tuple);
} while (more);
}
void exec_gd_start_iteration(GradientDescentState* gd_state)
{
if (gd_state->optimizer->start_iteration != nullptr)
gd_state->optimizer->start_iteration(gd_state->optimizer);
if (gd_state->shuffle->start_iteration != nullptr)
gd_state->shuffle->start_iteration(gd_state->shuffle);
}
void exec_gd_end_iteration(GradientDescentState* gd_state)
{
if (gd_state->shuffle->end_iteration != nullptr)
gd_state->shuffle->end_iteration(gd_state->shuffle);
if (gd_state->optimizer->end_iteration != nullptr)
gd_state->optimizer->end_iteration(gd_state->optimizer);
}
/* ----------------------------------------------------------------
* ExecGradientDescent
* ----------------------------------------------------------------
*
* Training and test are interleaved to avoid a double scan over the data
* for training and test. Iteration 0 only computes the initial weights, and
* at each following iteration the model is tested with the current weights
* and new weights are updated with the gradients. The optimization is clear:
* for N iterations, the basic algorithm requires N*2 data scans, while the
* interleaved train&test requires only N+1 data scans. When N=1 the number
* of scans is the same (N*2 = N+1)
*/
TupleTableSlot* ExecGradientDescent(GradientDescentState* gd_state)
{
// check if training is already finished
if (gd_state->done)
return NULL;
// get information from the node
const GradientDescent* gd_node = gd_get_node(gd_state);
ScanDirection direction = gd_state->ss.ps.state->es_direction;
PlanState* outer_plan = outerPlanState(gd_state);
// If backwards scan, just return NULL without changing state.
if (!ScanDirectionIsForward(direction))
return NULL;
// for counting execution time
uint64_t start, finish, step;
uint64_t iter_start, iter_finish;
// iterations
double prev_loss = 0;
TupleTableSlot* slot = NULL;
gd_state->processed = 0;
gd_state->discarded = 0;
uint64_t max_usecs = ULLONG_MAX;
if (gd_node->max_seconds > 0)
max_usecs = gd_node->max_seconds * 1000000ULL;
bool stop = false;
start = gd_get_clock_usecs();
step = start;
for (int iter = 0; !stop && iter <= gd_node->max_iterations; iter++) {
iter_start = gd_get_clock_usecs();
// init loss & scores
scores_init(&gd_state->scores);
gd_state->loss = 0;
exec_gd_start_iteration(gd_state);
exec_gd_batch(gd_state, iter);
exec_gd_end_iteration(gd_state);
iter_finish = gd_get_clock_usecs();
// delta loss < loss tolerance?
if (iter > 0)
stop = (fabs(prev_loss - gd_state->loss) < gd_node->tolerance);
if (!stop) {
// continue with another iteration with the new weights
int bytes = sizeof(gd_float) * gd_state->n_features;
int rc = memcpy_s(gd_state->weights.data, gd_state->weights.allocated * sizeof(gd_float),
gd_state->optimizer->weights.data, bytes);
securec_check(rc, "", "");
if (iter > 0) {
gd_state->n_iterations++;
if (gdhook_iteration != nullptr)
gdhook_iteration(gd_state);
}
// timeout || max_iterations
stop = (gd_get_clock_usecs()-start >= max_usecs)
|| (iter == gd_node->max_iterations);
}
// trace at end or no more than once per second
bool trace_iteration = gd_node->verbose || stop;
if (!trace_iteration) {
uint64_t now = gd_get_clock_usecs();
uint64_t nusecs = now - step;
if (nusecs > 1000000) {
// more than one second
trace_iteration = true;
step = now;
}
}
if (iter>0 && trace_iteration) {
gd_float* w = gd_state->weights.data;
StringInfoData buf;
initStringInfo(&buf);
for (int i=0 ; i<gd_state->n_features ; i++)
appendStringInfo(&buf, "%.3f,", w[i]);
ereport(DEBUG1,
(errmodule(MOD_DB4AI),
errmsg("ITERATION %d: test_loss=%.6f delta_loss=%.6f tolerance=%.3f accuracy=%.3f tuples=%d coef=%s",
iter, gd_state->loss,
fabs(prev_loss - gd_state->loss), gd_node->tolerance,
get_accuracy(&gd_state->scores), gd_state->processed,
buf.data)));
pfree(buf.data);
}
prev_loss = gd_state->loss;
if (!stop)
ExecReScan(outer_plan); // for the next iteration
}
finish = gd_get_clock_usecs();
gd_state->done = true;
gd_state->usecs = finish - start;
// return trained model
ExprDoneCond isDone;
slot = ExecProject(gd_state->ss.ps.ps_ProjInfo, &isDone);
return slot;
}
/* ----------------------------------------------------------------
* ExecInitGradientDescent
*
* This initializes the GradientDescent node state structures and
* the node's subplan.
* ----------------------------------------------------------------
*/
GradientDescentState* ExecInitGradientDescent(GradientDescent* gd_node, EState* estate, int eflags)
{
GradientDescentState* gd_state = NULL;
Plan* outer_plan = outerPlan(gd_node);
// check for unsupported flags
Assert(!(eflags & (EXEC_FLAG_REWIND | EXEC_FLAG_BACKWARD | EXEC_FLAG_MARK)));
// create state structure
gd_state = makeNode(GradientDescentState);
gd_state->ss.ps.plan = (Plan*)gd_node;
gd_state->ss.ps.state = estate;
// Tuple table initialization
ExecInitScanTupleSlot(estate, &gd_state->ss);
ExecInitResultTupleSlot(estate, &gd_state->ss.ps);
// initialize child expressions
ExecAssignExprContext(estate, &gd_state->ss.ps);
gd_state->ss.ps.targetlist = (List*)ExecInitExpr((Expr*)gd_node->plan.targetlist, (PlanState*)gd_state);
// initialize outer plan
outerPlanState(gd_state) = ExecInitNode(outer_plan, estate, eflags);
// Initialize result tuple type and projection info.
ExecAssignScanTypeFromOuterPlan(&gd_state->ss); // input tuples
ExecAssignResultTypeFromTL(&gd_state->ss.ps); // result tuple
ExecAssignProjectionInfo(&gd_state->ss.ps, NULL);
gd_state->ss.ps.ps_TupFromTlist = false;
// select algorithm
gd_state->algorithm = gd_get_algorithm(gd_node->algorithm);
// Input tuple initialization
gd_state->tupdesc = ExecGetResultType(outerPlanState(gd_state));
int natts = gd_state->tupdesc->natts;
gd_state->n_features = natts; // -1 dep_var, +1 bias (fixed as 1)
for (int i = 0; i < natts; i++) {
Oid oidtype = gd_state->tupdesc->attrs[i]->atttypid;
if (i == gd_node->targetcol) {
switch (oidtype) {
case BITOID:
case VARBITOID:
case BYTEAOID:
case CHAROID:
case RAWOID:
case NAMEOID:
case TEXTOID:
case BPCHAROID:
case VARCHAROID:
case NVARCHAR2OID:
case CSTRINGOID:
case INT1OID:
case INT2OID:
case INT4OID:
case INT8OID:
case FLOAT4OID:
case FLOAT8OID:
case NUMERICOID:
case ABSTIMEOID:
case DATEOID:
case TIMEOID:
case TIMESTAMPOID:
case TIMESTAMPTZOID:
case TIMETZOID:
case SMALLDATETIMEOID:
// detect the different values while reading the data
gd_state->num_classes = 0;
break;
case BOOLOID:
// values are known in advance
gd_state->binary_classes[0] = BoolGetDatum(false);
gd_state->binary_classes[1] = BoolGetDatum(true);
gd_state->num_classes = 2;
break;
default:
// unsupported datatypes
ereport(ERROR,
(errmodule(MOD_DB4AI),
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Datatype of target not supported")));
break;
}
}
}
// optimizer
switch (gd_node->optimizer) {
case OPTIMIZER_GD:
gd_state->optimizer = gd_init_optimizer_gd(gd_state);
break;
case OPTIMIZER_NGD:
gd_state->optimizer = gd_init_optimizer_ngd(gd_state);
break;
default:
ereport(ERROR,
(errmodule(MOD_DB4AI),
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("Optimizer %d not supported", gd_node->optimizer)));
break;
}
matrix_init(&gd_state->optimizer->weights, gd_state->n_features);
matrix_init(&gd_state->optimizer->gradients, gd_state->n_features);
// shuffle
gd_state->shuffle = gd_init_shuffle_cache(gd_state);
gd_state->shuffle->optimizer = gd_state->optimizer;
// training state initialization
gd_state->done = false;
gd_state->learning_rate = gd_node->learning_rate;
gd_state->n_iterations = 0;
gd_state->loss = 0;
matrix_init(&gd_state->weights, gd_state->n_features);
return gd_state;
}
/* ----------------------------------------------------------------
* ExecEndGradientDescent
*
* This shuts down the subplan and frees resources allocated
* to this node.
* ----------------------------------------------------------------
*/
void ExecEndGradientDescent(GradientDescentState* gd_state)
{
// release state
matrix_release(&gd_state->weights);
gd_state->shuffle->release(gd_state->shuffle);
matrix_release(&gd_state->optimizer->gradients);
matrix_release(&gd_state->optimizer->weights);
gd_state->optimizer->release(gd_state->optimizer);
ExecFreeExprContext(&gd_state->ss.ps);
ExecEndNode(outerPlanState(gd_state));
pfree(gd_state);
}
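The interleaving described above in the ExecGradientDescent comment (test with the current weights and train the next ones in the same scan, so N iterations cost N+1 scans instead of 2N) can be sketched as follows; `update`, `test_loss`, and the scalar model are illustrative stand-ins for the optimizer step, `test_callback`, and the weight matrix:

```python
def interleaved_gd(batches, w, update, test_loss, max_iters, tolerance):
    """Iteration 0 only initializes weights; every later scan tests the
    current weights while already computing the next ones."""
    prev_loss = None
    for it in range(max_iters + 1):
        loss = 0.0
        for X, y in batches:                 # a single scan over the data
            if it > 0:
                loss += test_loss(w, X, y)   # test with the current weights...
            w = update(w, X, y)              # ...and train the next ones
        if it > 0:
            if prev_loss is not None and abs(prev_loss - loss) < tolerance:
                break                        # converged: delta loss below tolerance
            prev_loss = loss
    return w

# fit y = 2x with one batch of three points
data = [([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])]
mse = lambda w, X, y: sum((w * x - t) ** 2 for x, t in zip(X, y)) / len(X)
step = lambda w, X, y: w - 0.05 * sum(2 * (w * x - t) * x for x, t in zip(X, y)) / len(X)
w = interleaved_gd(data, 0.0, step, mse, max_iters=200, tolerance=1e-9)
```

As in the executor node, the loop also illustrates why convergence is checked only from iteration 1 onward: iteration 0 produces no test loss to compare against.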

File diff suppressed because it is too large.


@ -60,6 +60,6 @@
#endif
#define NAILED_IN_CATALOG_NUM 8
#define CATALOG_NUM 91
#define CATALOG_NUM 92
#endif


@ -204,6 +204,7 @@ typedef enum ObjectClass {
OCLASS_DIRECTORY, /* pg_directory */
OCLASS_PG_JOB, /* pg_job */
OCLASS_RLSPOLICY, /* pg_rlspolicy */
OCLASS_DB4AI_MODEL, /* gs_model_warehouse */
MAX_OCLASS /* MUST BE LAST */
} ObjectClass;


@ -0,0 +1,144 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* -------------------------------------------------------------------------
*
* gs_model.h
*
* IDENTIFICATION
* src/include/catalog/gs_model.h
*
* -------------------------------------------------------------------------
*/
#ifndef GS_MODEL_H
#define GS_MODEL_H
#include "access/genam.h"
#include "access/heapam.h"
#include "access/sysattr.h"
#include "catalog/genbki.h"
#include "catalog/indexing.h"
#include "utils/fmgroids.h"
#include "utils/snapmgr.h"
#include "utils/syscache.h"
#include "utils/timestamp.h"
#include "utils/date.h"
#define ModelRelationId 3991
#define ModelRelation_Rowtype_Id 3994
#ifdef HAVE_INT64_TIMESTAMP
#define timestamp int64
#else
#define timestamp double
#endif
CATALOG(gs_model_warehouse,3991) BKI_ROWTYPE_OID(3994) BKI_SCHEMA_MACRO
{
NameData modelname; /* model name */
Oid modelowner; /* model owner */
timestamp createtime; /* Model storage time */
int4 processedtuples;
int4 discardedtuples;
float4 pre_process_time;
float4 exec_time;
int4 iterations;
Oid outputtype;
#ifdef CATALOG_VARLEN /* variable-length fields start here */
text modeltype; /* svm, kmeans, invalid */
text query;
bytea modeldata; /* OPTION just for ModelBinary*/
float4 weight[1];
text hyperparametersnames[1];
text hyperparametersvalues[1];
Oid hyperparametersoids[1];
text coefnames[1];
text coefvalues[1];
Oid coefoids[1];
text trainingscoresname[1];
float4 trainingscoresvalue[1];
text modeldescribe[1];
#endif /* CATALOG_VARLEN */
} FormData_gs_model_warehouse;
typedef FormData_gs_model_warehouse *Form_gs_model_warehouse;
#define Natts_gs_model_warehouse 22
#define Anum_gs_model_model_name 1
#define Anum_gs_model_owner_oid 2
#define Anum_gs_model_create_time 3
#define Anum_gs_model_processedTuples 4
#define Anum_gs_model_discardedTuples 5
#define Anum_gs_model_process_time_secs 6
#define Anum_gs_model_exec_time_secs 7
#define Anum_gs_model_iterations 8
#define Anum_gs_model_outputType 9
#define Anum_gs_model_model_type 10
#define Anum_gs_model_query 11
#define Anum_gs_model_modelData 12
#define Anum_gs_model_weight 13
#define Anum_gs_model_hyperparametersNames 14
#define Anum_gs_model_hyperparametersValues 15
#define Anum_gs_model_hyperparametersOids 16
#define Anum_gs_model_coefNames 17
#define Anum_gs_model_coefValues 18
#define Anum_gs_model_coefOids 19
#define Anum_gs_model_trainingScoresName 20
#define Anum_gs_model_trainingScoresValue 21
#define Anum_gs_model_modeldescribe 22
// Locate the oid for a given model name
inline Oid get_model_oid(const char* model_name, bool missing_ok)
{
Oid oid;
oid = GetSysCacheOid1(DB4AI_MODEL, CStringGetDatum(model_name));
if (!OidIsValid(oid) && !missing_ok) {
ereport(
ERROR, (errcode(ERRCODE_UNDEFINED_OBJECT), errmsg("model \"%s\" does not exist", model_name)));
}
return oid;
}
// Remove a model by oid
inline void remove_model_by_oid(Oid model_oid)
{
Relation rel;
HeapTuple tup;
ScanKeyData skey[1];
SysScanDesc scan;
ScanKeyInit(&skey[0], ObjectIdAttributeNumber, BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(model_oid));
rel = heap_open(ModelRelationId, RowExclusiveLock);
scan = systable_beginscan(rel, GsModelOidIndexId, true, SnapshotNow, 1, skey);
/* we expect exactly one match */
tup = systable_getnext(scan);
if (!HeapTupleIsValid(tup))
ereport(
ERROR, (errcode(ERRCODE_UNDEFINED_OBJECT), errmsg("could not find tuple for model entry %u", model_oid)));
simple_heap_delete(rel, &tup->t_self);
systable_endscan(scan);
heap_close(rel, RowExclusiveLock);
}
#endif


@ -547,6 +547,11 @@ DECLARE_UNIQUE_INDEX(gs_matviewdep_oid_index, 9992, on gs_matview_dependency usi
DECLARE_UNIQUE_INDEX(gs_opt_model_name_index, 9997, on gs_opt_model using btree(model_name name_ops));
#define GsOPTModelNameIndexId 9997
DECLARE_UNIQUE_INDEX(gs_model_oid_index, 3992, on gs_model_warehouse using btree(oid oid_ops));
#define GsModelOidIndexId 3992
DECLARE_UNIQUE_INDEX(gs_model_name_index, 3993, on gs_model_warehouse using btree(modelname name_ops));
#define GsModelNameIndexId 3993
/* last step of initialization script: build the indexes declared above */
BUILD_INDICES


@ -98,6 +98,9 @@ DATA(insert OID = 4989 ( "snapshot" PGUID 0 _null_ n));
DESCR("snapshot schema");
#define PG_SNAPSHOT_NAMESPACE 4989
DATA(insert OID = 4991 ( "db4ai" PGUID 0 _null_ n));
DESCR("db4ai schema");
#define PG_DB4AI_NAMESPACE 4991
/*
* prototypes for functions in pg_namespace.c
*/


@ -398,6 +398,13 @@ typedef FormData_pg_proc *Form_pg_proc;
#define PGCHECKAUTHIDFUNCOID 3228
#define JSONAGGFUNCOID 5206
#define JSONOBJECTAGGFUNCOID 5209
#define DB4AI_PREDICT_BY_BOOL_OID 7101
#define DB4AI_PREDICT_BY_INT32_OID 7102
#define DB4AI_PREDICT_BY_INT64_OID 7103
#define DB4AI_PREDICT_BY_FLOAT4_OID 7105
#define DB4AI_PREDICT_BY_FLOAT8_OID 7106
#define DB4AI_PREDICT_BY_NUMERIC_OID 7107
#define DB4AI_PREDICT_BY_TEXT_OID 7108
/*
* Symbolic values for prokind column


@ -24,7 +24,7 @@
#define PgxcNodeRelationId 9015
#define PgxcNodeRelation_Rowtype_Id 11649
-CATALOG(pgxc_node,9015) BKI_SHARED_RELATION BKI_SCHEMA_MACRO
+CATALOG(pgxc_node,9015) BKI_ROWTYPE_OID(11649) BKI_SHARED_RELATION BKI_SCHEMA_MACRO
{
NameData node_name;


@ -56,6 +56,7 @@ DECLARE_TOAST(pg_trigger, 2336, 2337);
DECLARE_TOAST(pg_partition, 5502, 5503);
DECLARE_TOAST(pgxc_class, 5506, 5507);
DECLARE_TOAST(pg_hashbucket, 4390, 4391);
DECLARE_TOAST(gs_model_warehouse, 3995, 3996);
/* shared catalogs */
DECLARE_TOAST(pg_shdescription, 2846, 2847);


@ -0,0 +1,65 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
 * aifuncs.h
*
* IDENTIFICATION
* src/include/dbmind/db4ai/nodes/aifuncs.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef DB4AI_AIFUNCS_H
#define DB4AI_AIFUNCS_H
#include "nodes/plannodes.h"
inline const char* algorithm_ml_to_string(AlgorithmML x)
{
switch(x) {
case LOGISTIC_REGRESSION: return "logistic_regression";
case SVM_CLASSIFICATION: return "svm_classification";
case LINEAR_REGRESSION: return "linear_regression";
case KMEANS: return "kmeans";
case INVALID_ALGORITHM_ML:
default: return "INVALID_ALGORITHM_ML";
}
}
inline AlgorithmML get_algorithm_ml(const char *str)
{
if (0 == strcmp(str, "logistic_regression")) {
return LOGISTIC_REGRESSION;
} else if (0 == strcmp(str, "svm_classification")) {
return SVM_CLASSIFICATION;
} else if (0 == strcmp(str, "linear_regression")) {
return LINEAR_REGRESSION;
} else if (0 == strcmp(str, "kmeans")) {
return KMEANS;
} else {
return INVALID_ALGORITHM_ML;
}
}
inline bool is_supervised(AlgorithmML algorithm)
{
if (algorithm == KMEANS) {
return false;
} else {
return true;
}
}
#endif // DB4AI_AIFUNCS_H


@ -0,0 +1,57 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ---------------------------------------------------------------------------------------
*
 * create_model.h
*
* IDENTIFICATION
* src/include/dbmind/db4ai/commands/create_model.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef CREATE_MODEL_H
#define CREATE_MODEL_H
#include "postgres.h"
#include "nodes/params.h"
#include "nodes/parsenodes.h"
#include "nodes/parsenodes_common.h"
#include "nodes/plannodes.h"
#include "tcop/dest.h"
struct Model;
struct DestReceiverTrainModel {
DestReceiver dest;
Model *model;
List *targetlist; // for gradient descent
};
void configure_dest_receiver_train_model(DestReceiverTrainModel *dest, AlgorithmML algorithm, const char *model_name,
const char *sql);
// Create a DestReceiver object for training model operators
DestReceiver *CreateTrainModelDestReceiver();
// Rewrites and plans a CREATE MODEL query. Used both during query execution
// and for EXPLAIN statements
PlannedStmt *plan_create_model(CreateModelStmt *stmt, const char *query_string, ParamListInfo params,
DestReceiver *dest);
// Execute a CREATE MODEL statement through the executor
void exec_create_model(CreateModelStmt *stmt, const char *queryString, ParamListInfo params, char *completionTag);
#endif


@ -0,0 +1,58 @@
/**
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
mrc_cpu.h
Code to provide micro-optimizations
IDENTIFICATION
src/include/dbmind/db4ai/executor/mrc_cpu.h
---------------------------------------------------------------------------------------
**/
#ifndef DB4AI_MRC_CPU_H
#define DB4AI_MRC_CPU_H
#if defined(__GNUC__)
#ifndef likely
#define likely(x) __builtin_expect((x) != 0, 1)
#endif
#ifndef unlikely
#define unlikely(x) __builtin_expect((x) != 0, 0)
#endif
#ifndef force_alignment
#define force_alignment(x) __attribute__((aligned((x))))
#endif
#ifndef force_inline
#define force_inline inline __attribute__((always_inline))
#endif
#ifndef prefetch
#define prefetch(address, rw, locality) (__builtin_prefetch(address, rw, locality))
#endif
#else
#define likely(x) (x)
#define unlikely(x) (x)
#define prefetch(address, rw, locality) ((void)0)
#define force_alignment(x)
#define force_inline inline
#endif
#endif //DB4AI_MRC_CPU_H


@ -0,0 +1,50 @@
/**
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
distance_functions.h
Current set of distance functions that can be used (for k-means for example)
IDENTIFICATION
src/include/dbmind/db4ai/executor/distance_functions.h
---------------------------------------------------------------------------------------
**/
#ifndef DB4AI_DISTANCE_FUNCTIONS_H
#define DB4AI_DISTANCE_FUNCTIONS_H
#include <cinttypes>
/*
* L1 distance (Manhattan)
*/
extern double l1(double const* p, double const* q, uint32_t dimension);
/*
* Squared Euclidean (default)
*/
extern double l2_squared(double const* p, double const* q, uint32_t dimension);
/*
* L2 distance (Euclidean)
*/
extern double l2(double const* p, double const* q, uint32_t dimension);
/*
* L infinity distance (Chebyshev)
*/
extern double linf(double const* p, double const* q, uint32_t dimension);
#endif //DB4AI_DISTANCE_FUNCTIONS_H


@ -0,0 +1,53 @@
/**
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
fp_ops.h
Robust floating point operations
IDENTIFICATION
src/include/dbmind/db4ai/executor/fp_ops.h
---------------------------------------------------------------------------------------
**/
#ifndef DB4AI_FP_OPS_H
#define DB4AI_FP_OPS_H
/*
* High precision sum: a + b = *sum + *e
*/
extern void twoSum(double a, double b, double* sum, double* e);
/*
* The equivalent subtraction a - b = *sub + *e
*/
extern void twoDiff(double a, double b, double* sub, double* e);
/*
* High precision product a * b = *mult + *e
*/
extern void twoMult(double a, double b, double* mult, double* e);
/*
* High precision square a * a = *square + *e (faster than twoMult(a, a,..))
*/
extern void square(double a, double* square, double* e);
/*
* High precision division a / b = *div + *e
*/
extern void twoDiv(double a, double b, double* div, double* e);
#endif //DB4AI_FP_OPS_H

src/include/db4ai/gd.h (new file, 163 lines)

@ -0,0 +1,163 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* gd.h
*
* IDENTIFICATION
* src/include/db4ai/gd.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef GD_H
#define GD_H
#include "nodes/execnodes.h"
#include "db4ai/predict_by.h"
#include <sys/time.h> // gettimeofday() used by gd_get_clock_usecs() below
// returns the current time in microseconds
inline uint64_t
gd_get_clock_usecs() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000000ULL + tv.tv_usec;
}
enum {
GD_DEPENDENT_VAR_CONTINUOUS = 0x00000000,
GD_DEPENDENT_VAR_BINARY = 0x00000001,
GD_DEPENDENT_VAR_CATEGORICAL = 0x00000002,
};
enum {
METRIC_ACCURACY = 0x0001, // (tp + tn) / n
METRIC_F1 = 0x0002, // 2 * (precision * recall) / (precision + recall)
METRIC_PRECISION = 0x0004, // tp / (tp + fp)
METRIC_RECALL = 0x0008, // tp / (tp + fn)
METRIC_LOSS = 0x0010, // defined by each algorithm
    METRIC_MSE = 0x0010, // sum((y-y')^2) / n
};
char* gd_get_metric_name(int metric);
struct GradientDescent;
typedef void (*f_gd_gradients)(const GradientDescent* gd_node, const Matrix* features, const Matrix* dep_var,
Matrix* weights, Matrix* gradients);
typedef double (*f_gd_test)(const GradientDescent* gd_node, const Matrix* features, const Matrix* dep_var,
const Matrix* weights, Scores* scores);
typedef gd_float (*f_gd_predict)(const Matrix* features, const Matrix* weights);
typedef struct GradientDescentAlgorithm {
const char* name;
int flags;
int metrics;
// values for binary algorithms
// e.g. (0,1) for logistic regression or (-1,1) for svm classifier
gd_float min_class;
gd_float max_class;
// callbacks for hooks
f_gd_gradients gradients_callback;
f_gd_test test_callback;
f_gd_predict predict_callback;
} GradientDescentAlgorithm;
GradientDescentAlgorithm* gd_get_algorithm(AlgorithmML algorithm);
inline bool dep_var_is_continuous(const GradientDescentAlgorithm* algo) {
return ((algo->flags & (GD_DEPENDENT_VAR_BINARY | GD_DEPENDENT_VAR_CATEGORICAL)) == 0);
}
inline bool dep_var_is_binary(const GradientDescentAlgorithm* algo) {
return ((algo->flags & GD_DEPENDENT_VAR_BINARY) != 0);
}
gd_float gd_datum_get_float(Oid type, Datum datum);
Datum gd_float_get_datum(Oid type, gd_float value);
inline Oid
get_atttypid(GradientDescentState* gd_state, int col) {
return gd_state->tupdesc->attrs[col]->atttypid;
}
inline bool
get_attbyval(GradientDescentState* gd_state, int col) {
return gd_state->tupdesc->attrs[col]->attbyval;
}
inline int
get_attlen(GradientDescentState* gd_state, int col) {
return gd_state->tupdesc->attrs[col]->attlen;
}
inline int
get_atttypmod(GradientDescentState* gd_state, int col) {
return gd_state->tupdesc->attrs[col]->atttypmod;
}
inline int
get_natts(GradientDescentState* gd_state) {
return gd_state->tupdesc->natts;
}
struct Model;
ModelPredictor gd_predict_prepare(const Model* model);
Datum gd_predict(ModelPredictor pred, Datum* values, bool* isnull, Oid* types, int values_size);
inline const GradientDescent* gd_get_node(const GradientDescentState* gd_state) {
return (const GradientDescent*)gd_state->ss.ps.plan;
}
const char* gd_get_expr_name(GradientDescentExprField field);
Datum ExecEvalGradientDescent(GradientDescentExprState* mlstate, ExprContext* econtext,
bool* isNull, ExprDoneCond* isDone);
////////////////////////////////////////////////////////////////////////
// optimizers
typedef struct OptimizerGD {
// shared members, initialized by the caller and not by the specialization
Matrix weights;
Matrix gradients;
// specialized methods
void (*start_iteration)(OptimizerGD* optimizer);
void (*update_batch)(OptimizerGD* optimizer, const Matrix* features, const Matrix* dep_var);
void (*end_iteration)(OptimizerGD* optimizer);
void (*release)(OptimizerGD* optimizer);
} OptimizerGD;
OptimizerGD* gd_init_optimizer_gd(const GradientDescentState* gd_state);
OptimizerGD* gd_init_optimizer_ngd(const GradientDescentState* gd_state);
const char* gd_get_optimizer_name(OptimizerML optimizer);
////////////////////////////////////////////////////////////////////////////
// shuffle
typedef struct ShuffleGD {
OptimizerGD* optimizer;
void (*start_iteration)(ShuffleGD* shuffle);
Matrix* (*get)(ShuffleGD* shuffle, Matrix** dep_var);
void (*unget)(ShuffleGD* shuffle, int tuples);
void (*end_iteration)(ShuffleGD* shuffle);
void (*release)(ShuffleGD* shuffle);
} ShuffleGD;
ShuffleGD* gd_init_shuffle_cache(const GradientDescentState* gd_state);
#endif /* GD_H */


@ -0,0 +1,125 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
*
*
* IDENTIFICATION
* src/include/db4ai/hyperparameter_validation.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef DB4AI_HYPERPARAMETER_VALIDATION_H
#define DB4AI_HYPERPARAMETER_VALIDATION_H
#include "postgres.h"
#include "db4ai/model_warehouse.h"
#include "nodes/pg_list.h"
#include "utils/builtins.h"
#include <float.h>
struct HyperparameterDefinition {
const char* name;
Datum default_value;
Datum min_value;
Datum max_value;
const char** valid_values;
void (* enum_setter)(const char* s, void* enum_addr);
Oid type;
int32_t valid_values_size;
int32_t offset;
bool min_inclusive;
bool max_inclusive;
};
struct HyperparameterValidation {
void* min_value;
bool min_inclusive;
void* max_value;
bool max_inclusive;
const char** valid_values;
int32_t valid_values_size;
};
// records the final value of a hyperparameter in the model warehouse output
void update_model_hyperparameter(Model* model, const char* name, Oid type, Datum value);
void configure_hyperparameters(AlgorithmML algorithm,
List* hyperparameters, Model* model, void* hyperparameter_struct);
// Set int hyperparameter
void set_hyperparameter_value(const char* name, int* hyperparameter,
Value* value, VariableSetKind kind, int default_value, Model* model,
HyperparameterValidation* validation);
// Set double hyperparameter
void set_hyperparameter_value(const char* name, double* hyperparameter, Value* value,
VariableSetKind kind, double default_value, Model* model,
HyperparameterValidation* validation);
// Set string hyperparameter (no const)
void set_hyperparameter_value(const char* name, char** hyperparameter, Value* value,
VariableSetKind kind, char* default_value, Model* model,
HyperparameterValidation* validation);
// Set boolean hyperparameter
void set_hyperparameter_value(const char* name, bool* hyperparameter,
Value* value, VariableSetKind kind, bool default_value, Model* model,
HyperparameterValidation* validation);
// General purpose method to set the hyperparameters:
// locate the hyperparameter in the list by name and set it to the selected value.
// Returns the index of the hyperparameter in the list, or -1 if it is not found
// (in which case the default value is applied)
template<typename T>
int set_hyperparameter(const char* name, T* hyperparameter, List* hyperparameters, T default_value, Model* model,
HyperparameterValidation* validation) {
int result = 0;
foreach_cell(it, hyperparameters) {
VariableSetStmt* current = lfirst_node(VariableSetStmt, it);
if (strcmp(current->name, name) == 0) {
if (list_length(current->args) > 1) {
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("Hyperparameter %s cannot be a list", current->name)));
}
Value* value = NULL;
if (current->args != NULL) {
A_Const* aconst = NULL;
aconst = linitial_node(A_Const, current->args);
value = &aconst->val;
}
set_hyperparameter_value(name, hyperparameter, value, current->kind, default_value, model, validation);
return result;
}
result++;
}
// If not set by user, set the default value
set_hyperparameter_value(name, hyperparameter, NULL, VAR_SET_DEFAULT, default_value, model, validation);
return -1;
}
#endif

src/include/db4ai/matrix.h (new file, 527 lines)

@ -0,0 +1,527 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* matrix.h
*
* IDENTIFICATION
* src/include/dbmind/db4ai/executor/gd/matrix.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef MATRIX_H
#define MATRIX_H
#include "postgres.h"
#include "lib/stringinfo.h"
#include "db4ai/scores.h"
#include "math.h"
#include "float.h"
#define MATRIX_CACHE 16
typedef float4 gd_float;
typedef struct Matrix {
int rows;
int columns;
bool transposed;
int allocated;
gd_float *data;
gd_float cache[MATRIX_CACHE];
} Matrix;
int matrix_expected_size(int rows, int columns = 1);
// initializes a bidimensional matrix or a vector (#columns = 1) to zeros
void matrix_init(Matrix *matrix, int rows, int columns = 1);
// initializes with a copy of a matrix
void matrix_init_clone(Matrix *matrix, const Matrix *src);
// initializes as a virtual transpose of a matrix; no data is copied
void matrix_init_transpose(Matrix *matrix, const Matrix *src);
// releases the memory of a matrix and makes it empty
void matrix_release(Matrix *matrix);
// fills a matrix with zeros
void matrix_zeroes(Matrix *matrix);
// changes the shape of a matrix
// currently only the number of rows can change; column resize will be done later
void matrix_resize(Matrix *matrix, int rows, int columns);
// multiplies a matrix by a vector, row by row
void matrix_mult_vector(const Matrix *matrix, const Matrix *vector, Matrix *result);
// multiplies two matrices entrywise (a.k.a. Hadamard or Schur product)
void matrix_mult_entrywise(Matrix *m1, const Matrix *m2);
// multiplies a matrix by a scalar, coefficient by coefficient
void matrix_mult_scalar(Matrix *matrix, gd_float factor);
// divides a matrix by a scalar, coefficient by coefficient
void matrix_divide(Matrix *matrix, gd_float factor);
// adds a matrix to another matrix, coefficient by coefficient
void matrix_add(Matrix *m1, const Matrix *m2);
// subtracts a matrix from another matrix, coefficient by coefficient
void matrix_subtract(Matrix *m1, const Matrix *m2);
// squares all coefficients
void matrix_square(Matrix *matrix);
// obtains the square root of all coefficients
void matrix_square_root(Matrix *matrix);
// computes the sigmoid of all coefficients: 1.0 / (1.0 + exp(-c))
void matrix_sigmoid(Matrix *matrix);
// computes the natural logarithm of all coefficients
void matrix_log(Matrix *matrix);
// computes the natural logarithm log(1+c) of all coefficients
void matrix_log1p(Matrix *matrix);
// negates all coefficients (-c)
void matrix_negate(Matrix *matrix);
// complements all coefficients (1-c)
void matrix_complement(Matrix *matrix);
// clamps all coefficients to c >= 0
void matrix_positive(Matrix *matrix);
// return the sum of all coefficients
gd_float matrix_get_sum(Matrix *matrix);
// scales a matrix row by row, using two vectors (N & D)
// where each coefficient is scaled as c' = (c - N) / D
// - normalization: N=min, D=(max-min)
// - standardization: N=avg, D=stdev
void matrix_scale(Matrix *matrix, const Matrix *m_n, const Matrix *m_d);
// computes the dot product of two vectors
gd_float matrix_dot(const Matrix *v1, const Matrix *v2);
// converts all coefficients to binary values w.r.t. a threshold
// low: v<threshold; high: v>=threshold
void matrix_binary(Matrix *matrix, gd_float threshold, gd_float low, gd_float high);
// compares two binary vectors
void matrix_relevance(const Matrix *v1, const Matrix *v2, Scores *scores, gd_float positive);
// prints to a buffer
void matrix_print(const Matrix *matrix, StringInfo buf, bool full);
// prints into the log
void elog_matrix(int elevel, const char *msg, const Matrix *matrix);
// ///////////////////////////////////////////////////////////////////////////
// inline
inline int matrix_expected_size(int rows, int columns)
{
Assert(rows > 0);
Assert(columns > 0);
int cells = rows * columns;
if (cells <= MATRIX_CACHE)
cells = 0; // cached, no extra memory
return sizeof(Matrix) + cells * sizeof(gd_float);
}
inline void matrix_init(Matrix *matrix, int rows, int columns)
{
Assert(matrix != nullptr);
Assert(rows > 0);
Assert(columns > 0);
matrix->transposed = false;
matrix->rows = rows;
matrix->columns = columns;
matrix->allocated = rows * columns;
if (matrix->allocated <= MATRIX_CACHE) {
matrix->data = matrix->cache;
errno_t rc = memset_s(matrix->data, MATRIX_CACHE * sizeof(gd_float), 0, matrix->allocated * sizeof(gd_float));
securec_check(rc, "", "");
} else
matrix->data = (gd_float *)palloc0(matrix->allocated * sizeof(gd_float));
}
inline void matrix_init_clone(Matrix *matrix, const Matrix *src)
{
Assert(matrix != nullptr);
Assert(src != nullptr);
Assert(!src->transposed);
matrix_init(matrix, src->rows, src->columns);
size_t bytes = src->rows * src->columns * sizeof(gd_float);
errno_t rc = memcpy_s(matrix->data, bytes, src->data, bytes);
securec_check(rc, "", "");
}
// fake transpose, only points to the data of the other matrix
inline void matrix_init_transpose(Matrix *matrix, const Matrix *src)
{
Assert(matrix != nullptr);
Assert(src != nullptr);
Assert(!src->transposed);
matrix->transposed = true;
matrix->rows = src->columns;
matrix->columns = src->rows;
matrix->allocated = 0;
matrix->data = src->data;
}
inline void matrix_release(Matrix *matrix)
{
Assert(matrix != nullptr);
if (matrix->allocated > 0) {
Assert(matrix->data != nullptr);
if (matrix->data != matrix->cache)
pfree(matrix->data);
matrix->allocated = 0;
}
matrix->data = nullptr;
matrix->rows = 0;
matrix->columns = 0;
}
inline void matrix_zeroes(Matrix *matrix)
{
Assert(matrix != nullptr);
Assert(!matrix->transposed);
errno_t rc = memset_s(matrix->data, sizeof(gd_float) * matrix->allocated, 0,
sizeof(gd_float) * matrix->rows * matrix->columns);
securec_check(rc, "", "");
}
inline void matrix_resize(Matrix *matrix, int rows, int columns)
{
Assert(matrix != nullptr);
Assert(!matrix->transposed);
Assert(rows > 0);
Assert(columns > 0);
if (columns != matrix->columns)
ereport(ERROR,
(errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED), errmsg("resize column not yet supported")));
if (rows != matrix->rows) {
if (rows > matrix->rows) {
int required = rows * matrix->columns;
if (required > matrix->allocated)
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("matrix growth not yet supported")));
}
matrix->rows = rows;
}
}
inline void matrix_mult_vector(const Matrix *matrix, const Matrix *vector, Matrix *result)
{
Assert(matrix != nullptr);
Assert(vector != nullptr);
Assert(vector->columns == 1);
Assert(!vector->transposed);
Assert(result != nullptr);
Assert(!result->transposed);
Assert(result->columns == 1);
Assert(matrix->rows == result->rows);
Assert(matrix->columns == vector->rows);
if (matrix->transposed) {
gd_float *pd = result->data;
// loop assumes that the data has not been physically transposed
for (int r = 0; r < matrix->rows; r++) {
const gd_float *pm = matrix->data + r;
const gd_float *pv = vector->data;
gd_float x = 0.0;
for (int c = 0; c < matrix->columns; c++) {
x += *pm * *pv++;
pm += matrix->rows;
}
*pd++ = x;
}
} else {
const gd_float *pm = matrix->data;
gd_float *pd = result->data;
for (int r = 0; r < matrix->rows; r++) {
const gd_float *pv = vector->data;
size_t count = matrix->columns;
gd_float x = 0.0;
while (count-- > 0)
x += *pv++ * *pm++;
*pd++ = x;
}
}
}
inline void matrix_mult_entrywise(Matrix *m1, const Matrix *m2)
{
Assert(m1 != nullptr);
Assert(!m1->transposed);
Assert(m2 != nullptr);
Assert(!m2->transposed);
Assert(m1->rows == m2->rows);
Assert(m1->columns == m2->columns);
size_t count = m1->rows * m1->columns;
gd_float *pd = m1->data;
const gd_float *ps = m2->data;
while (count-- > 0)
*pd++ *= *ps++;
}
inline void matrix_mult_scalar(Matrix *matrix, gd_float factor)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0)
*pd++ *= factor;
}
inline void matrix_divide(Matrix *matrix, gd_float factor)
{
Assert(matrix != nullptr);
Assert(factor != 0.0);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0)
*pd++ /= factor;
}
inline void matrix_add(Matrix *m1, const Matrix *m2)
{
Assert(m1 != nullptr);
Assert(!m1->transposed);
Assert(m2 != nullptr);
Assert(!m2->transposed);
Assert(m1->rows == m2->rows);
Assert(m1->columns == m2->columns);
size_t count = m1->rows * m1->columns;
gd_float *p1 = m1->data;
const gd_float *p2 = m2->data;
while (count-- > 0)
*p1++ += *p2++;
}
inline void matrix_subtract(Matrix *m1, const Matrix *m2)
{
Assert(m1 != nullptr);
Assert(!m1->transposed);
Assert(m2 != nullptr);
Assert(!m2->transposed);
Assert(m1->rows == m2->rows);
Assert(m1->columns == m2->columns);
size_t count = m1->rows * m1->columns;
gd_float *p1 = m1->data;
const gd_float *p2 = m2->data;
while (count-- > 0)
*p1++ -= *p2++;
}
inline gd_float matrix_dot(const Matrix *v1, const Matrix *v2)
{
Assert(v1 != nullptr);
Assert(!v1->transposed);
Assert(v2 != nullptr);
Assert(!v2->transposed);
Assert(v1->rows == v2->rows);
Assert(v1->columns == 1);
Assert(v2->columns == 1);
size_t count = v1->rows;
const gd_float *p1 = v1->data;
const gd_float *p2 = v2->data;
gd_float result = 0;
while (count-- > 0)
result += *p1++ * *p2++;
return result;
}
inline void matrix_square(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
*pd *= *pd;
pd++;
}
}
inline void matrix_square_root(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
*pd = sqrt(*pd);
pd++;
}
}
inline void matrix_sigmoid(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float c = *pd;
*pd++ = 1.0 / (1.0 + exp(-c));
}
}
inline void matrix_log(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = *pd;
*pd++ = log(v);
}
}
inline void matrix_log1p(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = *pd + 1;
*pd++ = log(v);
}
}
inline void matrix_negate(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = *pd;
*pd++ = -v;
}
}
inline void matrix_complement(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = 1.0 - *pd;
*pd++ = v;
}
}
inline void matrix_positive(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = *pd;
if (v < 0.0)
*pd = 0.0;
pd++;
}
}
inline gd_float matrix_get_sum(Matrix *matrix)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *ps = matrix->data;
gd_float s = 0.0;
while (count-- > 0)
s += *ps++;
return s;
}
inline void matrix_scale(Matrix *matrix, const Matrix *m_n, const Matrix *m_d)
{
Assert(matrix != nullptr);
Assert(m_n != nullptr);
Assert(m_d != nullptr);
Assert(!matrix->transposed);
Assert(!m_n->transposed);
Assert(matrix->columns == m_n->rows);
Assert(m_n->columns == 1);
Assert(!m_d->transposed);
Assert(matrix->columns == m_d->rows);
Assert(m_d->columns == 1);
gd_float *pd = matrix->data;
for (int r = 0; r < matrix->rows; r++) {
const gd_float *p1 = m_n->data;
const gd_float *p2 = m_d->data;
for (int c = 0; c < matrix->columns; c++) {
*pd = (*pd - *p1++) / *p2++;
pd++;
}
}
}
inline void matrix_binary(Matrix *matrix, gd_float threshold, gd_float low, gd_float high)
{
Assert(matrix != nullptr);
size_t count = matrix->rows * matrix->columns;
gd_float *pd = matrix->data;
while (count-- > 0) {
gd_float v = *pd;
*pd++ = (v < threshold ? low : high);
}
}
inline void matrix_relevance(const Matrix *v1, const Matrix *v2, Scores *scores, gd_float positive)
{
Assert(v1 != nullptr);
Assert(!v1->transposed);
Assert(v2 != nullptr);
Assert(!v2->transposed);
Assert(v1->rows == v2->rows);
Assert(v1->columns == 1);
Assert(v2->columns == 1);
size_t count = v1->rows;
const gd_float *p1 = v1->data;
const gd_float *p2 = v2->data;
while (count-- > 0) {
gd_float x = *p1++;
gd_float y = *p2++;
if (x == positive) {
// positive
if (y == positive)
scores->tp++;
else
scores->fp++;
} else {
// negative
if (y != positive)
scores->tn++;
else
scores->fn++;
}
scores->count++;
}
}
#endif /* MATRIX_H */


@ -0,0 +1,112 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
 * model_warehouse.h
*
* IDENTIFICATION
* src/gausskernel/catalog/model_warehouse.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef DB4AI_MODEL_WAREHOUSE
#define DB4AI_MODEL_WAREHOUSE
#include "postgres.h"
#include "nodes/pg_list.h"
#include "nodes/plannodes.h"
#include "nodes/execnodes.h"
struct Hyperparameter{
const char* name;
Oid type;
Datum value;
};
struct TrainingInfo{
const char* name;
Oid type;
Datum value;
};
struct TrainingScore{
const char* name;
double value;
};
// Base class for models
struct Model{
AlgorithmML algorithm;
const char* model_name;
const char* sql;
double exec_time_secs;
double pre_time_secs;
int64_t processed_tuples;
int64_t discarded_tuples;
List* train_info; // List of TrainingInfo
List* hyperparameters; // List of Hyperparameter
List* scores; // List of TrainingScore
Oid return_type; // Return type of the model
int32_t num_actual_iterations;
};
// Used by all GradientDescent variants
struct ModelGradientDescent{
Model model;
Datum weights; // Float[]
int ncategories; // 0 for continuous
Datum categories; // only for categorical, an array of return_type[ncategories]
};
// Used by K-Means models
typedef struct WHCentroid {
double objective_function = DBL_MAX;
double avg_distance_to_centroid = DBL_MAX;
double min_distance_to_centroid = DBL_MAX;
double max_distance_to_centroid = DBL_MAX;
double std_dev_distance_to_centroid = DBL_MAX;
uint64_t cluster_size = 0ULL;
double* coordinates = nullptr;
uint32_t id = 0U;
} WHCentroid;
struct ModelKMeans {
Model model;
/*
* the following fields are put here for convenience
*/
uint32_t original_num_centroids = 0U;
uint32_t actual_num_centroids = 0U;
uint32_t dimension = 0U;
uint32_t distance_function_id = 0U;
uint64_t seed = 0ULL;
WHCentroid* centroids = nullptr;
};
// Used by XGBoost
struct ModelBinary {
Model model;
uint64_t model_len;
Datum model_data; // varlena (void*)
};
// Store the model in the catalog tables
void store_model(const Model* model);
// Get the model from the catalog tables
Model* get_model(const char* model_name, bool only_model);
#endif

@ -0,0 +1,51 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* predict_by.h
* IDENTIFICATION
* src/include/db4ai/predict_by.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef DB4AI_PREDICT_BY_H
#define DB4AI_PREDICT_BY_H
#include "postgres.h"
#include "db4ai/model_warehouse.h"
#include "nodes/parsenodes_common.h"
#include "fmgr/fmgr_comp.h"
#include "parser/parse_node.h"
typedef void* ModelPredictor; // Deserialized version of the model that can compute predictions efficiently
struct PredictorInterface {
ModelPredictor (*prepare) (const Model* model);
Datum (*predict) (ModelPredictor pred, Datum* values, bool* isnull, Oid* types, int values_size);
};
Datum db4ai_predict_by_bool(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_int32(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_int64(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_float4(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_float8(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_numeric(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_text(PG_FUNCTION_ARGS);
#endif

@ -0,0 +1,76 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* scores.h
*
* IDENTIFICATION
* src/include/dbmind/db4ai/executor/gd/scores.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef SCORES_H
#define SCORES_H
#include "postgres.h"
typedef struct Scores {
int count;
// for binary classification
int tp; // true positive
int tn; // true negative
int fp; // false positive
int fn; // false negative
// for continuous regression
float mse; // mean squared error
} Scores;
inline void scores_init(Scores *scores)
{
errno_t rc = memset_s(scores, sizeof(Scores), 0, sizeof(Scores));
securec_check(rc, "", "");
}
// (tp + tn) / n
inline double get_accuracy(const Scores *scores)
{
return (scores->tp + scores->tn) / (double)scores->count;
}
// tp / (tp + fp)
inline double get_precision(const Scores *scores, bool *has)
{
double d = scores->tp + scores->fp;
if (d > 0) {
*has = true;
return scores->tp / d;
}
*has = false;
return 0;
}
// tp / (tp + fn)
inline double get_recall(const Scores *scores, bool *has)
{
double d = scores->tp + scores->fn;
if (d > 0) {
*has = true;
return scores->tp / d;
}
*has = false;
return 0;
}
#endif /* SCORES_H */

@ -0,0 +1,38 @@
/*
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
*
* openGauss is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*---------------------------------------------------------------------------------------
*
* nodeGD.h
*
* IDENTIFICATION
* src/include/executor/nodeGD.h
*
* ---------------------------------------------------------------------------------------
*/
#ifndef NODE_GD_H
#define NODE_GD_H
#include "nodes/execnodes.h"
extern GradientDescentState* ExecInitGradientDescent(GradientDescent* node, EState* estate, int eflags);
extern TupleTableSlot* ExecGradientDescent(GradientDescentState* state);
extern void ExecEndGradientDescent(GradientDescentState* state);
typedef void (*GradientDescentHook_iteration)(GradientDescentState* state);
extern GradientDescentHook_iteration gdhook_iteration;
List* makeGradientDescentExpr(AlgorithmML algorithm, List* list, int field);
#endif /* NODE_GD_H */

@ -0,0 +1,42 @@
/**
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
---------------------------------------------------------------------------------------
nodeKMeans.h
Functions related to the k-means operator
IDENTIFICATION
src/include/executor/nodeKMeans.h
---------------------------------------------------------------------------------------
**/
#ifndef DB4AI_NODEKMEANS_H
#define DB4AI_NODEKMEANS_H
#include "nodes/execnodes.h"
/*
* do not change this unless you understand the implications for
* nodeKMeans.cpp and create_model.cpp. You have been warned.
*/
uint32_t constexpr NUM_ATTR_OUTPUT = 15U;
extern KMeansState* ExecInitKMeans(KMeans* node, EState* estate, int eflags);
extern TupleTableSlot* ExecKMeans(KMeansState* node);
extern void ExecEndKMeans(KMeansState* node);
#endif //DB4AI_NODEKMEANS_H

@ -21,6 +21,7 @@
#include "nodes/params.h"
#include "nodes/plannodes.h"
#include "storage/pagecompress.h"
#include "utils/array.h"
#include "utils/bloom_filter.h"
#include "utils/reltrigger.h"
#include "utils/sortsupport.h"
@ -28,6 +29,8 @@
#include "utils/tuplestore.h"
#include "vecexecutor/vectorbatch.h"
#include "db4ai/matrix.h"
#ifdef ENABLE_MOT
// forward declaration for MOT JitContext
namespace JitExec
@ -2483,5 +2486,133 @@ inline bool BitmapNodeNeedSwitchPartRel(BitmapHeapScanState* node)
extern RangeScanInRedis reset_scan_qual(Relation currHeapRel, ScanState *node);
/*
* DB4AI GD node: used to train models using Gradient Descent
*/
struct GradientDescentAlgorithm;
struct GradientDescentExpr;
struct OptimizerGD;
struct ShuffleGD;
typedef struct GradientDescentState {
ScanState ss; /* its first field is NodeTag */
GradientDescentAlgorithm* algorithm;
OptimizerGD* optimizer;
ShuffleGD* shuffle;
// tuple description
TupleDesc tupdesc;
int n_features; // number of features
// dependent var binary values
int num_classes;
Datum binary_classes[2];
// training state
bool done; // when finished
Matrix weights;
double learning_rate;
int n_iterations;
int usecs; // execution time
int processed; // tuples
int discarded;
float loss;
Scores scores;
} GradientDescentState;
typedef struct GradientDescentExprState {
ExprState xprstate;
PlanState* ps;
GradientDescentExpr* xpr;
} GradientDescentExprState;
/*
* DB4AI k-means node
*/
/*
* to keep running statistics on each cluster being constructed
*/
class IncrementalStatistics {
uint64_t population = 0;
double max_value = 0.;
double min_value = DBL_MAX;
double total = 0.;
double s = 0;
public:
IncrementalStatistics operator+(IncrementalStatistics const& rhs) const;
IncrementalStatistics operator-(IncrementalStatistics const& rhs) const;
IncrementalStatistics& operator+=(IncrementalStatistics const& rhs);
IncrementalStatistics& operator-=(IncrementalStatistics const& rhs);
double getMin() const;
void setMin(double);
double getMax() const;
void setMax(double);
double getTotal() const;
void setTotal(double);
uint64_t getPopulation() const;
void setPopulation(uint64_t);
double getEmpiricalMean() const;
double getEmpiricalVariance() const;
double getEmpiricalStdDev() const;
bool reset();
};
/*
* internal representation of a centroid
*/
typedef struct Centroid {
IncrementalStatistics statistics;
ArrayType* coordinates = nullptr;
uint32_t id = 0U;
} Centroid;
/*
* current state of the algorithm
*/
typedef struct KMeansStateDescription {
Centroid* centroids[2] = {nullptr};
ArrayType* bbox_min = nullptr;
ArrayType* bbox_max = nullptr;
double (* distance)(double const*, double const*, uint32_t const dimension) = nullptr;
double execution_time = 0.;
double seeding_time = 0.;
IncrementalStatistics solution_statistics[2];
uint64_t num_good_points = 0UL;
uint64_t num_dead_points = 0UL;
uint32_t current_iteration = 0U;
uint32_t current_centroid = 0U;
uint32_t dimension = 0U;
uint32_t num_centroids = 0U;
uint32_t actual_num_iterations = 0U;
} KMeansStateDescription;
/*
* current state of the k-means node
*/
typedef struct KMeansState {
ScanState sst;
KMeansStateDescription description;
bool done = false;
} KMeansState;
/*
* internal representation of a point (not a centroid)
*/
typedef struct GSPoint {
ArrayType* pg_coordinates = nullptr;
uint32_t weight = 1U;
uint32_t id = 0U;
double distance_to_closest_centroid = DBL_MAX;
bool should_free = false;
} GSPoint;
#endif /* EXECNODES_H */

@ -725,7 +725,17 @@ typedef enum NodeTag {
T_ClientLogicColumnParam,
T_CreateClientLogicColumn,
T_ClientLogicColumnRef,
T_ExprWithComma
T_ExprWithComma,
// DB4AI
T_CreateModelStmt = 5000,
T_PredictByFunction,
T_GradientDescent,
T_GradientDescentState,
T_GradientDescentExpr,
T_GradientDescentExprState,
T_KMeans,
T_KMeansState,
// End DB4AI
} NodeTag;
/* if you add to NodeTag also need to add nodeTagToString */

@ -55,6 +55,7 @@ typedef enum ObjectType {
OBJECT_CONVERSION,
OBJECT_DATABASE,
OBJECT_DATA_SOURCE,
OBJECT_DB4AI_MODEL, // DB4AI
OBJECT_DOMAIN,
OBJECT_EXTENSION,
OBJECT_FDW,
@ -1776,4 +1777,35 @@ typedef struct RenameStmt {
bool missing_ok; /* skip error if missing? */
} RenameStmt;
/* ----------------------
* Create Model Statement
* ----------------------
*/
typedef struct CreateModelStmt{ // DB4AI
NodeTag type;
char* model;
char* architecture;
List* hyperparameters; // List<VariableSetStmt>
Node* select_query; // Query to be executed: SelectStmt -> Query
List* model_features; // FEATURES clause
List* model_target; // TARGET clause
// Filled during transform
AlgorithmML algorithm; // Algorithm to be executed
} CreateModelStmt;
/* ----------------------
* Prediction BY function
* ----------------------
*/
typedef struct PredictByFunction{ // DB4AI
NodeTag type;
char* model_name;
int model_name_location; // Only for parser
List* model_args;
int model_args_location; // Only for parser
} PredictByFunction;
#endif /* PARSENODES_COMMONH */

@ -1399,5 +1399,111 @@ static inline bool IsJoinPlan(Node* node)
IsA(node, VecMergeJoin) || IsA(node, HashJoin) || IsA(node, VecHashJoin);
}
/*
* DB4AI
*/
// GD optimizers
typedef enum {
OPTIMIZER_GD, // simple mini-batch
OPTIMIZER_NGD, // normalized gradient descent
} OptimizerML;
inline void optimizer_ml_setter(const char* str, void* optimizer_ml)
{
OptimizerML* optimizer = (OptimizerML*)optimizer_ml;
if (strcmp(str, "gd") == 0)
*optimizer = OPTIMIZER_GD;
else if (strcmp(str, "ngd") == 0)
*optimizer = OPTIMIZER_NGD;
else
elog(ERROR, "Invalid optimizer. Current candidates are: gd (default), ngd");
}
// Gradient Descent node
typedef struct GradientDescent {
Plan plan;
AlgorithmML algorithm;
int targetcol;
// generic hyperparameters
OptimizerML optimizer; // default GD/mini-batch
int max_seconds; // 0 to disable
bool verbose;
int max_iterations; // maximum number of iterations
int batch_size;
double learning_rate;
double decay; // (0:1], learning rate decay
double tolerance; // [0:1], 0 means to run all iterations
int seed; // [0:N], random seed
// for SVM
double lambda; // regularization strength
} GradientDescent;
/*
* DB4AI k-means
*/
/*
* current available distance functions
*/
typedef enum : uint32_t {
KMEANS_L1 = 0U,
KMEANS_L2,
KMEANS_L2_SQUARED,
KMEANS_LINF
} DistanceFunction;
/*
* current available seeding method
*/
typedef enum : uint32_t {
KMEANS_RANDOM_SEED = 0U,
KMEANS_BB
} SeedingFunction;
/*
* Verbosity level
*/
typedef enum : uint32_t {
NO_OUTPUT = 0U,
FASTCHECK_OUTPUT,
VERBOSE_OUTPUT
} Verbosity;
/*
* description of the k-means instance
*/
struct KMeansDescription {
char const* model_name = nullptr;
SeedingFunction seeding = KMEANS_RANDOM_SEED;
DistanceFunction distance = KMEANS_L2_SQUARED;
Verbosity verbosity = NO_OUTPUT;
uint32_t n_features = 0U;
uint32_t batch_size = 0U;
};
/*
* current hyper-parameters
*/
struct KMeansHyperParameters {
uint32_t num_centroids = 0U;
uint32_t num_iterations = 0U;
uint64_t external_seed = 0ULL;
double tolerance = 0.00001;
};
/*
* the actual k-means operator
*/
typedef struct KMeans {
Plan plan;
AlgorithmML algorithm;
KMeansDescription description;
KMeansHyperParameters parameters;
} KMeans;
#endif /* PLANNODES_H */

@ -1401,4 +1401,42 @@ typedef struct UpsertExpr {
bool partKeyUpsert; /* we allow upsert index key and partition key in B_FORMAT */
} UpsertExpr;
/*
* DB4AI
*/
#define DB4AI_SNAPSHOT_VERSION_DELIMITER 1
#define DB4AI_SNAPSHOT_VERSION_SEPARATOR 2
typedef enum {
LOGISTIC_REGRESSION,
SVM_CLASSIFICATION,
KMEANS,
LINEAR_REGRESSION,
INVALID_ALGORITHM_ML,
} AlgorithmML;
typedef enum GradientDescentExprField {
// generic
GD_EXPR_ALGORITHM,
GD_EXPR_OPTIMIZER,
GD_EXPR_RESULT_TYPE,
GD_EXPR_NUM_ITERATIONS,
GD_EXPR_EXEC_TIME_MSECS,
GD_EXPR_PROCESSED_TUPLES,
GD_EXPR_DISCARDED_TUPLES,
GD_EXPR_WEIGHTS,
GD_EXPR_CATEGORIES,
// scores
GD_EXPR_SCORE = 0x10000, // or-ed with the score id
} GradientDescentExprField;
#define makeGradientDescentExprFieldScore(_SCORE) (GradientDescentExprField)((int)GD_EXPR_SCORE | _SCORE)
typedef struct GradientDescentExpr {
Expr xpr;
GradientDescentExprField field;
Oid fieldtype; /* pg_type OID of the datatype */
} GradientDescentExpr;
#endif /* PRIMNODES_H */

@ -226,6 +226,7 @@ PG_KEYWORD("extract", EXTRACT, COL_NAME_KEYWORD)
PG_KEYWORD("false", FALSE_P, RESERVED_KEYWORD)
PG_KEYWORD("family", FAMILY, UNRESERVED_KEYWORD)
PG_KEYWORD("fast", FAST, UNRESERVED_KEYWORD)
PG_KEYWORD("features", FEATURES, UNRESERVED_KEYWORD)
PG_KEYWORD("fenced", FENCED, RESERVED_KEYWORD)
PG_KEYWORD("fetch", FETCH, RESERVED_KEYWORD)
PG_KEYWORD("fileheader", FILEHEADER_P, UNRESERVED_KEYWORD)
@ -348,6 +349,7 @@ PG_KEYWORD("minus", MINUS_P, RESERVED_KEYWORD)
PG_KEYWORD("minute", MINUTE_P, UNRESERVED_KEYWORD)
PG_KEYWORD("minvalue", MINVALUE, UNRESERVED_KEYWORD)
PG_KEYWORD("mode", MODE, UNRESERVED_KEYWORD)
PG_KEYWORD("model", MODEL, UNRESERVED_KEYWORD)
PG_KEYWORD("modify", MODIFY_P, RESERVED_KEYWORD)
PG_KEYWORD("month", MONTH_P, UNRESERVED_KEYWORD)
PG_KEYWORD("move", MOVE, UNRESERVED_KEYWORD)
@ -422,6 +424,7 @@ PG_KEYWORD("pool", POOL, UNRESERVED_KEYWORD)
PG_KEYWORD("position", POSITION, COL_NAME_KEYWORD)
PG_KEYWORD("preceding", PRECEDING, UNRESERVED_KEYWORD)
PG_KEYWORD("precision", PRECISION, COL_NAME_KEYWORD)
PG_KEYWORD("predict", PREDICT, UNRESERVED_KEYWORD)
/* PGXC_BEGIN */
PG_KEYWORD("preferred", PREFERRED, UNRESERVED_KEYWORD)
/* PGXC_END */
@ -441,6 +444,7 @@ PG_KEYWORD("query", QUERY, UNRESERVED_KEYWORD)
PG_KEYWORD("quote", QUOTE, UNRESERVED_KEYWORD)
PG_KEYWORD("randomized", RANDOMIZED, UNRESERVED_KEYWORD)
PG_KEYWORD("range", RANGE, UNRESERVED_KEYWORD)
PG_KEYWORD("ratio", RATIO, UNRESERVED_KEYWORD)
PG_KEYWORD("raw", RAW, UNRESERVED_KEYWORD)
PG_KEYWORD("read", READ, UNRESERVED_KEYWORD)
PG_KEYWORD("real", REAL, COL_NAME_KEYWORD)
@ -484,6 +488,7 @@ PG_KEYWORD("rownum", ROWNUM, RESERVED_KEYWORD)
#endif
PG_KEYWORD("rows", ROWS, UNRESERVED_KEYWORD)
PG_KEYWORD("rule", RULE, UNRESERVED_KEYWORD)
PG_KEYWORD("sample", SAMPLE, UNRESERVED_KEYWORD)
PG_KEYWORD("savepoint", SAVEPOINT, UNRESERVED_KEYWORD)
PG_KEYWORD("schema", SCHEMA, UNRESERVED_KEYWORD)
PG_KEYWORD("scroll", SCROLL, UNRESERVED_KEYWORD)
@ -542,6 +547,7 @@ PG_KEYWORD("table", TABLE, RESERVED_KEYWORD)
PG_KEYWORD("tables", TABLES, UNRESERVED_KEYWORD)
PG_KEYWORD("tablesample", TABLESAMPLE, TYPE_FUNC_NAME_KEYWORD)
PG_KEYWORD("tablespace", TABLESPACE, UNRESERVED_KEYWORD)
PG_KEYWORD("target", TARGET, UNRESERVED_KEYWORD)
PG_KEYWORD("temp", TEMP, UNRESERVED_KEYWORD)
PG_KEYWORD("template", TEMPLATE, UNRESERVED_KEYWORD)
PG_KEYWORD("temporary", TEMPORARY, UNRESERVED_KEYWORD)

@ -108,6 +108,8 @@ typedef enum {
DestBatchLocalRedistribute, /* results send to consumer thread in a local redistribute way */
DestBatchLocalRoundRobin, /* results send to consumer thread in a local roundrobin way */
DestTrainModel, /* results send to DB4AI model warehouse */
DestBatchHybrid,
DestTransientRel /* results sent to transient relation */

View File

@ -102,6 +102,7 @@ enum ModuleId {
MOD_THREAD_POOL, /* thread_pool */
MOD_OPT_AI, /* ai optimizer */
MOD_GEN_COL, /* generated column */
MOD_DB4AI, /* DB4AI & AUTOML */
/* add your module id above */
MOD_MAX

@ -60,6 +60,8 @@ enum SysCacheIdentifier {
DATABASEOID,
DATASOURCENAME,
DATASOURCEOID,
DB4AI_MODELOID,
DB4AI_MODEL,
DEFACLROLENSPOBJ,
DIRECTORYNAME,
DIRECTORYOID,

@ -266,6 +266,9 @@ fastcheck_single: all tablespace-setup
$(call hotpatch_setup_func) && \
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule0$(PART) -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)
fastcheck_single_db4ai: all tablespace-setup
export LD_LIBRARY_PATH=$(SSL_LIB_PATH):$(LD_LIBRARY_PATH) && \
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule.db4ai -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)
fastcheck_single_mot: all tablespace-setup
export LD_LIBRARY_PATH=$(SSL_LIB_PATH):$(LD_LIBRARY_PATH) && \
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule20 -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_single_mot_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)

@ -0,0 +1,60 @@
3151, 1,0,0, 0.655000000000000027, 0.505000000000000004, 0.165000000000000008, 1.36699999999999999, 0.583500000000000019, 0.351499999999999979, 0.396000000000000019, 10
2026, 1,0,0, 0.550000000000000044, 0.469999999999999973, 0.149999999999999994, 0.920499999999999985, 0.381000000000000005, 0.243499999999999994, 0.267500000000000016, 10
3751, 0,1,0, 0.434999999999999998, 0.375, 0.110000000000000001, 0.41549999999999998, 0.170000000000000012, 0.0759999999999999981, 0.14499999999999999, 8
720, 0,1,0, 0.149999999999999994, 0.100000000000000006, 0.0250000000000000014, 0.0149999999999999994, 0.00449999999999999966, 0.00400000000000000008, 0.0050000000000000001, 2
1635, 1,0,0, 0.574999999999999956, 0.469999999999999973, 0.154999999999999999, 1.1160000000000001, 0.509000000000000008, 0.237999999999999989, 0.340000000000000024, 10
2648, 0,1,0, 0.5, 0.390000000000000013, 0.125, 0.582999999999999963, 0.293999999999999984, 0.132000000000000006, 0.160500000000000004, 8
1796, 1,0,0, 0.57999999999999996, 0.429999999999999993, 0.170000000000000012, 1.47999999999999998, 0.65349999999999997, 0.32400000000000001, 0.41549999999999998, 10
209, 1,0,0, 0.525000000000000022, 0.41499999999999998, 0.170000000000000012, 0.832500000000000018, 0.275500000000000023, 0.168500000000000011, 0.309999999999999998, 13
1451, 0,1,0, 0.455000000000000016, 0.33500000000000002, 0.135000000000000009, 0.501000000000000001, 0.274000000000000021, 0.0995000000000000051, 0.106499999999999997, 7
1108, 0,1,0, 0.510000000000000009, 0.380000000000000004, 0.115000000000000005, 0.515499999999999958, 0.214999999999999997, 0.113500000000000004, 0.166000000000000009, 8
3675, 1,0,0, 0.594999999999999973, 0.450000000000000011, 0.165000000000000008, 1.08099999999999996, 0.489999999999999991, 0.252500000000000002, 0.279000000000000026, 12
2108, 1,0,0, 0.675000000000000044, 0.550000000000000044, 0.179999999999999993, 1.68849999999999989, 0.562000000000000055, 0.370499999999999996, 0.599999999999999978, 15
3312, 1,0,0, 0.479999999999999982, 0.380000000000000004, 0.135000000000000009, 0.507000000000000006, 0.191500000000000004, 0.13650000000000001, 0.154999999999999999, 12
882, 0,0,1, 0.655000000000000027, 0.520000000000000018, 0.165000000000000008, 1.40949999999999998, 0.585999999999999965, 0.290999999999999981, 0.405000000000000027, 9
3402, 0,0,1, 0.479999999999999982, 0.395000000000000018, 0.149999999999999994, 0.681499999999999995, 0.214499999999999996, 0.140500000000000014, 0.2495, 18
829, 0,1,0, 0.409999999999999976, 0.325000000000000011, 0.100000000000000006, 0.394000000000000017, 0.20799999999999999, 0.0655000000000000027, 0.105999999999999997, 6
1305, 0,0,1, 0.535000000000000031, 0.434999999999999998, 0.149999999999999994, 0.716999999999999971, 0.347499999999999976, 0.14449999999999999, 0.194000000000000006, 9
3613, 0,0,1, 0.599999999999999978, 0.46000000000000002, 0.179999999999999993, 1.1399999999999999, 0.422999999999999987, 0.257500000000000007, 0.364999999999999991, 10
1068, 0,1,0, 0.340000000000000024, 0.265000000000000013, 0.0800000000000000017, 0.201500000000000012, 0.0899999999999999967, 0.0475000000000000006, 0.0550000000000000003, 5
2446, 0,0,1, 0.5, 0.380000000000000004, 0.135000000000000009, 0.583500000000000019, 0.22950000000000001, 0.126500000000000001, 0.179999999999999993, 12
1393, 0,0,1, 0.635000000000000009, 0.474999999999999978, 0.170000000000000012, 1.19350000000000001, 0.520499999999999963, 0.269500000000000017, 0.366499999999999992, 10
359, 0,0,1, 0.744999999999999996, 0.584999999999999964, 0.214999999999999997, 2.49900000000000011, 0.92649999999999999, 0.471999999999999975, 0.699999999999999956, 17
549, 1,0,0, 0.564999999999999947, 0.450000000000000011, 0.160000000000000003, 0.79500000000000004, 0.360499999999999987, 0.155499999999999999, 0.23000000000000001, 12
1154, 1,0,0, 0.599999999999999978, 0.474999999999999978, 0.160000000000000003, 1.02649999999999997, 0.484999999999999987, 0.2495, 0.256500000000000006, 9
1790, 1,0,0, 0.54500000000000004, 0.385000000000000009, 0.149999999999999994, 1.11850000000000005, 0.542499999999999982, 0.244499999999999995, 0.284499999999999975, 9
3703, 1,0,0, 0.665000000000000036, 0.540000000000000036, 0.195000000000000007, 1.76400000000000001, 0.850500000000000034, 0.361499999999999988, 0.469999999999999973, 11
1962, 1,0,0, 0.655000000000000027, 0.515000000000000013, 0.179999999999999993, 1.41199999999999992, 0.619500000000000051, 0.248499999999999999, 0.496999999999999997, 11
1665, 0,1,0, 0.604999999999999982, 0.469999999999999973, 0.14499999999999999, 0.802499999999999991, 0.379000000000000004, 0.226500000000000007, 0.220000000000000001, 9
635, 0,0,1, 0.359999999999999987, 0.294999999999999984, 0.100000000000000006, 0.210499999999999993, 0.0660000000000000031, 0.0524999999999999981, 0.0749999999999999972, 9
3901, 0,0,1, 0.445000000000000007, 0.344999999999999973, 0.140000000000000013, 0.475999999999999979, 0.205499999999999988, 0.101500000000000007, 0.108499999999999999, 15
2734, 0,1,0, 0.41499999999999998, 0.33500000000000002, 0.100000000000000006, 0.357999999999999985, 0.169000000000000011, 0.067000000000000004, 0.104999999999999996, 7
3856, 0,0,1, 0.409999999999999976, 0.33500000000000002, 0.115000000000000005, 0.440500000000000003, 0.190000000000000002, 0.0850000000000000061, 0.135000000000000009, 8
827, 0,1,0, 0.395000000000000018, 0.28999999999999998, 0.0950000000000000011, 0.303999999999999992, 0.127000000000000002, 0.0840000000000000052, 0.076999999999999999, 6
3381, 0,1,0, 0.190000000000000002, 0.130000000000000004, 0.0449999999999999983, 0.0264999999999999993, 0.00899999999999999932, 0.0050000000000000001, 0.00899999999999999932, 5
3972, 0,1,0, 0.400000000000000022, 0.294999999999999984, 0.0950000000000000011, 0.252000000000000002, 0.110500000000000001, 0.0575000000000000025, 0.0660000000000000031, 6
1155, 0,0,1, 0.599999999999999978, 0.455000000000000016, 0.170000000000000012, 1.1915, 0.695999999999999952, 0.239499999999999991, 0.239999999999999991, 8
3467, 0,0,1, 0.640000000000000013, 0.5, 0.170000000000000012, 1.4544999999999999, 0.642000000000000015, 0.357499999999999984, 0.353999999999999981, 9
2433, 1,0,0, 0.609999999999999987, 0.484999999999999987, 0.165000000000000008, 1.08699999999999997, 0.425499999999999989, 0.232000000000000012, 0.380000000000000004, 11
552, 0,1,0, 0.614999999999999991, 0.489999999999999991, 0.154999999999999999, 0.988500000000000045, 0.41449999999999998, 0.195000000000000007, 0.344999999999999973, 13
1425, 1,0,0, 0.729999999999999982, 0.57999999999999996, 0.190000000000000002, 1.73750000000000004, 0.678499999999999992, 0.434499999999999997, 0.520000000000000018, 11
2402, 1,0,0, 0.584999999999999964, 0.41499999999999998, 0.154999999999999999, 0.69850000000000001, 0.299999999999999989, 0.145999999999999991, 0.195000000000000007, 12
1748, 1,0,0, 0.699999999999999956, 0.535000000000000031, 0.174999999999999989, 1.77299999999999991, 0.680499999999999994, 0.479999999999999982, 0.512000000000000011, 15
3983, 0,1,0, 0.57999999999999996, 0.434999999999999998, 0.149999999999999994, 0.891499999999999959, 0.362999999999999989, 0.192500000000000004, 0.251500000000000001, 6
335, 1,0,0, 0.739999999999999991, 0.599999999999999978, 0.195000000000000007, 1.97399999999999998, 0.597999999999999976, 0.408499999999999974, 0.709999999999999964, 16
1587, 0,1,0, 0.515000000000000013, 0.349999999999999978, 0.104999999999999996, 0.474499999999999977, 0.212999999999999995, 0.122999999999999998, 0.127500000000000002, 10
2448, 0,1,0, 0.275000000000000022, 0.204999999999999988, 0.0800000000000000017, 0.096000000000000002, 0.0359999999999999973, 0.0184999999999999991, 0.0299999999999999989, 6
1362, 1,0,0, 0.604999999999999982, 0.474999999999999978, 0.174999999999999989, 1.07600000000000007, 0.463000000000000023, 0.219500000000000001, 0.33500000000000002, 9
2799, 0,0,1, 0.640000000000000013, 0.484999999999999987, 0.149999999999999994, 1.09800000000000009, 0.519499999999999962, 0.222000000000000003, 0.317500000000000004, 10
1413, 1,0,0, 0.67000000000000004, 0.505000000000000004, 0.174999999999999989, 1.01449999999999996, 0.4375, 0.271000000000000019, 0.3745, 10
1739, 1,0,0, 0.67000000000000004, 0.540000000000000036, 0.195000000000000007, 1.61899999999999999, 0.739999999999999991, 0.330500000000000016, 0.465000000000000024, 11
1152, 0,0,1, 0.584999999999999964, 0.465000000000000024, 0.160000000000000003, 0.955500000000000016, 0.45950000000000002, 0.235999999999999988, 0.265000000000000013, 7
2427, 0,1,0, 0.564999999999999947, 0.434999999999999998, 0.154999999999999999, 0.782000000000000028, 0.271500000000000019, 0.16800000000000001, 0.284999999999999976, 14
1777, 0,0,1, 0.484999999999999987, 0.369999999999999996, 0.154999999999999999, 0.967999999999999972, 0.418999999999999984, 0.245499999999999996, 0.236499999999999988, 9
3294, 0,0,1, 0.574999999999999956, 0.455000000000000016, 0.184999999999999998, 1.15599999999999992, 0.552499999999999991, 0.242999999999999994, 0.294999999999999984, 13
1403, 0,0,1, 0.650000000000000022, 0.510000000000000009, 0.190000000000000002, 1.54200000000000004, 0.715500000000000025, 0.373499999999999999, 0.375, 9
2256, 0,0,1, 0.510000000000000009, 0.395000000000000018, 0.14499999999999999, 0.61850000000000005, 0.215999999999999998, 0.138500000000000012, 0.239999999999999991, 12
3984, 1,0,0, 0.584999999999999964, 0.450000000000000011, 0.125, 0.873999999999999999, 0.354499999999999982, 0.20749999999999999, 0.225000000000000006, 6
1116, 0,0,1, 0.525000000000000022, 0.405000000000000027, 0.119999999999999996, 0.755499999999999949, 0.3755, 0.155499999999999999, 0.201000000000000012, 9
1366, 0,0,1, 0.609999999999999987, 0.474999999999999978, 0.170000000000000012, 1.02649999999999997, 0.434999999999999998, 0.233500000000000013, 0.303499999999999992, 10
3759, 0,1,0, 0.525000000000000022, 0.400000000000000022, 0.140000000000000013, 0.605500000000000038, 0.260500000000000009, 0.107999999999999999, 0.209999999999999992, 9

Some files were not shown because too many files have changed in this diff.