forked from openGauss-Ecosystem/openGauss-server
commit
b2cc4b3ee9
|
@ -0,0 +1,128 @@
|
|||
<!--
|
||||
doc/src/sgml/ref/alter_operator.sgml
|
||||
PostgreSQL documentation
|
||||
-->
|
||||
|
||||
<refentry id="SQL-ALTEROPERATOR">
|
||||
<refmeta>
|
||||
<refentrytitle>ALTER OPERATOR</refentrytitle>
|
||||
<manvolnum>7</manvolnum>
|
||||
<refmiscinfo>SQL - Language Statements</refmiscinfo>
|
||||
</refmeta>
|
||||
|
||||
<refnamediv>
|
||||
<refname>ALTER OPERATOR</refname>
|
||||
<refpurpose>change the definition of an operator</refpurpose>
|
||||
</refnamediv>
|
||||
|
||||
<indexterm zone="sql-alteroperator">
|
||||
<primary>ALTER OPERATOR</primary>
|
||||
</indexterm>
|
||||
|
||||
<refsynopsisdiv>
|
||||
<synopsis>
|
||||
ALTER OPERATOR <replaceable>name</replaceable> ( { <replaceable>left_type</replaceable> | NONE } , { <replaceable>right_type</replaceable> | NONE } ) OWNER TO <replaceable>new_owner</replaceable>
|
||||
ALTER OPERATOR <replaceable>name</replaceable> ( { <replaceable>left_type</replaceable> | NONE } , { <replaceable>right_type</replaceable> | NONE } ) SET SCHEMA <replaceable>new_schema</replaceable>
|
||||
</synopsis>
|
||||
</refsynopsisdiv>
|
||||
|
||||
<refsect1>
|
||||
<title>Description</title>
|
||||
|
||||
<para>
|
||||
<command>ALTER OPERATOR</command> changes the definition of
|
||||
an operator. The only currently available functionality is to change the
|
||||
owner or the schema of the operator.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
You must own the operator to use <command>ALTER OPERATOR</>.
|
||||
To alter the owner, you must also be a direct or indirect member of the new
|
||||
owning role, and that role must have <literal>CREATE</literal> privilege on
|
||||
the operator's schema. (These restrictions enforce that altering the owner
|
||||
doesn't do anything you couldn't do by dropping and recreating the operator.
|
||||
However, a superuser can alter ownership of any operator anyway.)
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Parameters</title>
|
||||
|
||||
<variablelist>
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">name</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The name (optionally schema-qualified) of an existing operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">left_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's left operand; write
|
||||
<literal>NONE</literal> if the operator has no left operand.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">right_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's right operand; write
|
||||
<literal>NONE</literal> if the operator has no right operand.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">new_owner</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The new owner of the operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">new_schema</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The new schema for the operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
</variablelist>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Examples</title>
|
||||
|
||||
<para>
|
||||
Change the owner of a custom operator <literal>a @@ b</literal> for type <type>text</type>:
|
||||
<programlisting>
|
||||
ALTER OPERATOR @@ (text, text) OWNER TO joe;
|
||||
</programlisting></para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Compatibility</title>
|
||||
|
||||
<para>
|
||||
There is no <command>ALTER OPERATOR</command> statement in
|
||||
the SQL standard.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>See Also</title>
|
||||
|
||||
<simplelist type="inline">
|
||||
<member><xref linkend="sql-createoperator"></member>
|
||||
<member><xref linkend="sql-dropoperator"></member>
|
||||
</simplelist>
|
||||
</refsect1>
|
||||
</refentry>
|
|
@ -0,0 +1,296 @@
|
|||
<!--
|
||||
doc/src/sgml/ref/create_operator.sgml
|
||||
PostgreSQL documentation
|
||||
-->
|
||||
|
||||
<refentry id="SQL-CREATEOPERATOR">
|
||||
<refmeta>
|
||||
<refentrytitle>CREATE OPERATOR</refentrytitle>
|
||||
<manvolnum>7</manvolnum>
|
||||
<refmiscinfo>SQL - Language Statements</refmiscinfo>
|
||||
</refmeta>
|
||||
|
||||
<refnamediv>
|
||||
<refname>CREATE OPERATOR</refname>
|
||||
<refpurpose>define a new operator</refpurpose>
|
||||
</refnamediv>
|
||||
|
||||
<indexterm zone="sql-createoperator">
|
||||
<primary>CREATE OPERATOR</primary>
|
||||
</indexterm>
|
||||
|
||||
<refsynopsisdiv>
|
||||
<synopsis>
|
||||
CREATE OPERATOR <replaceable>name</replaceable> (
|
||||
PROCEDURE = <replaceable class="parameter">function_name</replaceable>
|
||||
[, LEFTARG = <replaceable class="parameter">left_type</replaceable> ] [, RIGHTARG = <replaceable class="parameter">right_type</replaceable> ]
|
||||
[, COMMUTATOR = <replaceable class="parameter">com_op</replaceable> ] [, NEGATOR = <replaceable class="parameter">neg_op</replaceable> ]
|
||||
[, RESTRICT = <replaceable class="parameter">res_proc</replaceable> ] [, JOIN = <replaceable class="parameter">join_proc</replaceable> ]
|
||||
[, HASHES ] [, MERGES ]
|
||||
)
|
||||
</synopsis>
|
||||
</refsynopsisdiv>
|
||||
|
||||
<refsect1>
|
||||
<title>Description</title>
|
||||
|
||||
<para>
|
||||
<command>CREATE OPERATOR</command> defines a new operator,
|
||||
<replaceable class="parameter">name</replaceable>. The user who
|
||||
defines an operator becomes its owner. If a schema name is given
|
||||
then the operator is created in the specified schema. Otherwise it
|
||||
is created in the current schema.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
The operator name is a sequence of up to <symbol>NAMEDATALEN</>-1
|
||||
(63 by default) characters from the following list:
|
||||
<literallayout>
|
||||
+ - * / < > = ~ ! @ # % ^ & | ` ?
|
||||
</literallayout>
|
||||
|
||||
There are a few restrictions on your choice of name:
|
||||
<itemizedlist>
|
||||
<listitem>
|
||||
<para><literal>--</literal> and <literal>/*</literal> cannot appear anywhere in an operator name,
|
||||
since they will be taken as the start of a comment.
|
||||
</para>
|
||||
</listitem>
|
||||
<listitem>
|
||||
<para>
|
||||
A multicharacter operator name cannot end in <literal>+</literal> or
|
||||
<literal>-</literal>,
|
||||
unless the name also contains at least one of these characters:
|
||||
<literallayout>
|
||||
~ ! @ # % ^ & | ` ?
|
||||
</literallayout>
|
||||
For example, <literal>@-</literal> is an allowed operator name,
|
||||
but <literal>*-</literal> is not.
|
||||
This restriction allows <productname>PostgreSQL</productname> to
|
||||
parse SQL-compliant commands without requiring spaces between tokens.
|
||||
</para>
|
||||
</listitem>
|
||||
<listitem>
|
||||
<para>
|
||||
The use of <literal>=></> as an operator name is deprecated. It may
|
||||
be disallowed altogether in a future release.
|
||||
</para>
|
||||
</listitem>
|
||||
</itemizedlist>
|
||||
</para>
|
||||
|
||||
<para>
|
||||
The operator <literal>!=</literal> is mapped to
|
||||
<literal><></literal> on input, so these two names are always
|
||||
equivalent.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
At least one of <literal>LEFTARG</> and <literal>RIGHTARG</> must be defined. For
|
||||
binary operators, both must be defined. For right unary
|
||||
operators, only <literal>LEFTARG</> should be defined, while for left
|
||||
unary operators only <literal>RIGHTARG</> should be defined.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
The <replaceable class="parameter">function_name</replaceable>
|
||||
procedure must have been previously defined using <command>CREATE
|
||||
FUNCTION</command> and must be defined to accept the correct number
|
||||
of arguments (either one or two) of the indicated types.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
The other clauses specify optional operator optimization clauses.
|
||||
Their meaning is detailed in <xref linkend="xoper-optimization">.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
To be able to create an operator, you must have <literal>USAGE</literal>
|
||||
privilege on the argument types and the return type, as well
|
||||
as <literal>EXECUTE</literal> privilege on the underlying function. If a
|
||||
commutator or negator operator is specified, you must own these operators.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Parameters</title>
|
||||
|
||||
<variablelist>
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">name</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The name of the operator to be defined. See above for allowable
|
||||
characters. The name can be schema-qualified, for example
|
||||
<literal>CREATE OPERATOR myschema.+ (...)</>. If not, then
|
||||
the operator is created in the current schema. Two operators
|
||||
in the same schema can have the same name if they operate on
|
||||
different data types. This is called
|
||||
<firstterm>overloading</>.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">function_name</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The function used to implement this operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">left_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's left operand, if any.
|
||||
This option would be omitted for a left-unary operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">right_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's right operand, if any.
|
||||
This option would be omitted for a right-unary operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">com_op</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The commutator of this operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">neg_op</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The negator of this operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">res_proc</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The restriction selectivity estimator function for this operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">join_proc</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The join selectivity estimator function for this operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><literal>HASHES</literal></term>
|
||||
<listitem>
|
||||
<para>
|
||||
Indicates this operator can support a hash join.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><literal>MERGES</literal></term>
|
||||
<listitem>
|
||||
<para>
|
||||
Indicates this operator can support a merge join.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
</variablelist>
|
||||
|
||||
<para>
|
||||
To give a schema-qualified operator name in <replaceable
|
||||
class="parameter">com_op</replaceable> or the other optional
|
||||
arguments, use the <literal>OPERATOR()</> syntax, for example:
|
||||
<programlisting>
|
||||
COMMUTATOR = OPERATOR(myschema.===) ,
|
||||
</programlisting></para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Notes</title>
|
||||
|
||||
<para>
|
||||
Refer to <xref linkend="xoper"> for further information.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
It is not possible to specify an operator's lexical precedence in
|
||||
<command>CREATE OPERATOR</>, because the parser's precedence behavior
|
||||
is hard-wired. See <xref linkend="sql-precedence"> for precedence details.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
The obsolete options <literal>SORT1</>, <literal>SORT2</>,
|
||||
<literal>LTCMP</>, and <literal>GTCMP</> were formerly used to
|
||||
specify the names of sort operators associated with a merge-joinable
|
||||
operator. This is no longer necessary, since information about
|
||||
associated operators is found by looking at B-tree operator families
|
||||
instead. If one of these options is given, it is ignored except
|
||||
for implicitly setting <literal>MERGES</> true.
|
||||
</para>
|
||||
|
||||
<para>
|
||||
Use <xref linkend="sql-dropoperator"> to delete user-defined operators
|
||||
from a database. Use <xref linkend="sql-alteroperator"> to modify operators in a
|
||||
database.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Examples</title>
|
||||
|
||||
<para>
|
||||
The following command defines a new operator, area-equality, for
|
||||
the data type <type>box</type>:
|
||||
<programlisting>
|
||||
CREATE OPERATOR === (
|
||||
LEFTARG = box,
|
||||
RIGHTARG = box,
|
||||
PROCEDURE = area_equal_procedure,
|
||||
COMMUTATOR = ===,
|
||||
NEGATOR = !==,
|
||||
RESTRICT = area_restriction_procedure,
|
||||
JOIN = area_join_procedure,
|
||||
HASHES, MERGES
|
||||
);
|
||||
</programlisting></para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Compatibility</title>
|
||||
|
||||
<para>
|
||||
<command>CREATE OPERATOR</command> is a
|
||||
<productname>PostgreSQL</productname> extension. There are no
|
||||
provisions for user-defined operators in the SQL standard.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>See Also</title>
|
||||
|
||||
<simplelist type="inline">
|
||||
<member><xref linkend="sql-alteroperator"></member>
|
||||
<member><xref linkend="sql-createopclass"></member>
|
||||
<member><xref linkend="sql-dropoperator"></member>
|
||||
</simplelist>
|
||||
</refsect1>
|
||||
</refentry>
|
|
@ -0,0 +1,146 @@
|
|||
<!--
|
||||
doc/src/sgml/ref/drop_operator.sgml
|
||||
PostgreSQL documentation
|
||||
-->
|
||||
|
||||
<refentry id="SQL-DROPOPERATOR">
|
||||
<refmeta>
|
||||
<refentrytitle>DROP OPERATOR</refentrytitle>
|
||||
<manvolnum>7</manvolnum>
|
||||
<refmiscinfo>SQL - Language Statements</refmiscinfo>
|
||||
</refmeta>
|
||||
|
||||
<refnamediv>
|
||||
<refname>DROP OPERATOR</refname>
|
||||
<refpurpose>remove an operator</refpurpose>
|
||||
</refnamediv>
|
||||
|
||||
<indexterm zone="sql-dropoperator">
|
||||
<primary>DROP OPERATOR</primary>
|
||||
</indexterm>
|
||||
|
||||
<refsynopsisdiv>
|
||||
<synopsis>
|
||||
DROP OPERATOR [ IF EXISTS ] <replaceable class="parameter">name</replaceable> ( { <replaceable class="parameter">left_type</replaceable> | NONE } , { <replaceable class="parameter">right_type</replaceable> | NONE } ) [ CASCADE | RESTRICT ]
|
||||
</synopsis>
|
||||
</refsynopsisdiv>
|
||||
|
||||
<refsect1>
|
||||
<title>Description</title>
|
||||
|
||||
<para>
|
||||
<command>DROP OPERATOR</command> drops an existing operator from
|
||||
the database system. To execute this command you must be the owner
|
||||
of the operator.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Parameters</title>
|
||||
|
||||
<variablelist>
|
||||
|
||||
<varlistentry>
|
||||
<term><literal>IF EXISTS</literal></term>
|
||||
<listitem>
|
||||
<para>
|
||||
Do not throw an error if the operator does not exist. A notice is issued
|
||||
in this case.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">name</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The name (optionally schema-qualified) of an existing operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">left_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's left operand; write
|
||||
<literal>NONE</literal> if the operator has no left operand.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><replaceable class="parameter">right_type</replaceable></term>
|
||||
<listitem>
|
||||
<para>
|
||||
The data type of the operator's right operand; write
|
||||
<literal>NONE</literal> if the operator has no right operand.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><literal>CASCADE</literal></term>
|
||||
<listitem>
|
||||
<para>
|
||||
Automatically drop objects that depend on the operator.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
|
||||
<varlistentry>
|
||||
<term><literal>RESTRICT</literal></term>
|
||||
<listitem>
|
||||
<para>
|
||||
Refuse to drop the operator if any objects depend on it. This
|
||||
is the default.
|
||||
</para>
|
||||
</listitem>
|
||||
</varlistentry>
|
||||
</variablelist>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Examples</title>
|
||||
|
||||
<para>
|
||||
Remove the power operator <literal>a^b</literal> for type <type>integer</type>:
|
||||
<programlisting>
|
||||
DROP OPERATOR ^ (integer, integer);
|
||||
</programlisting>
|
||||
</para>
|
||||
|
||||
<para>
|
||||
Remove the left unary bitwise complement operator
|
||||
<literal>~b</literal> for type <type>bit</type>:
|
||||
<programlisting>
|
||||
DROP OPERATOR ~ (none, bit);
|
||||
</programlisting>
|
||||
</para>
|
||||
|
||||
<para>
|
||||
Remove the right unary factorial operator <literal>x!</literal>
|
||||
for type <type>bigint</type>:
|
||||
<programlisting>
|
||||
DROP OPERATOR ! (bigint, none);
|
||||
</programlisting></para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>Compatibility</title>
|
||||
|
||||
<para>
|
||||
There is no <command>DROP OPERATOR</command> statement in the SQL standard.
|
||||
</para>
|
||||
</refsect1>
|
||||
|
||||
<refsect1>
|
||||
<title>See Also</title>
|
||||
|
||||
<simplelist type="inline">
|
||||
<member><xref linkend="sql-createoperator"></member>
|
||||
<member><xref linkend="sql-alteroperator"></member>
|
||||
</simplelist>
|
||||
</refsect1>
|
||||
|
||||
</refentry>
|
|
@ -57,7 +57,7 @@ POSTGRES_BKI_SRCS = $(addprefix $(top_srcdir)/src/include/catalog/,\
|
|||
pg_job.h gs_asp.h pg_job_proc.h pg_extension_data_source.h pg_statistic_ext.h pg_object.h pg_synonym.h \
|
||||
toasting.h indexing.h gs_obsscaninfo.h pg_directory.h pg_hashbucket.h gs_global_config.h\
|
||||
pg_streaming_stream.h pg_streaming_cont_query.h pg_streaming_reaper_status.h gs_matview.h\
|
||||
gs_matview_dependency.h pgxc_slice.h gs_opt_model.h\
|
||||
gs_matview_dependency.h pgxc_slice.h gs_opt_model.h gs_model.h\
|
||||
)
|
||||
|
||||
# location of Catalog.pm
|
||||
|
|
|
@ -2088,6 +2088,34 @@
|
|||
"datetimetz_pl", 1,
|
||||
AddBuiltinFunc(_0(1297), _1("datetimetz_pl"), _2(2), _3(true), _4(false), _5(datetimetz_timestamptz), _6(1184), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(0), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, 1082, 1266), _21(NULL), _22(NULL), _23(NULL), _24(NULL), _25("datetimetz_timestamptz"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(NULL), _32(false), _33(NULL), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_bool", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_BOOL_OID), _1("db4ai_predict_by_bool"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_bool), _6(BOOLOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_bool"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_float4", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_FLOAT4_OID), _1("db4ai_predict_by_float4"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_float4), _6(FLOAT4OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_float4"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_float8", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_FLOAT8_OID), _1("db4ai_predict_by_float8"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_float8), _6(FLOAT8OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_float8"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_int32", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_INT32_OID), _1("db4ai_predict_by_int32"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_int32), _6(INT4OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_int32"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_int64", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_INT64_OID), _1("db4ai_predict_by_int64"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_int64), _6(INT8OID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_int64"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_numeric", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_NUMERIC_OID), _1("db4ai_predict_by_numeric"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_numeric), _6(NUMERICOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_numeric"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"db4ai_predict_by_text", 1,
|
||||
AddBuiltinFunc(_0(DB4AI_PREDICT_BY_TEXT_OID), _1("db4ai_predict_by_text"), _2(2), _3(false), _4(false), _5(db4ai_predict_by_text), _6(TEXTOID), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(ANYOID), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(2, TEXTOID, ANYOID), _21(2, TEXTOID, ANYOID), _22(2, 'i', 'v'), _23(NULL), _24(NULL), _25("db4ai_predict_by_text"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(false), _32(false), _33("f"), _34('f'))
|
||||
),
|
||||
AddFuncGroup(
|
||||
"dcbrt", 1,
|
||||
AddBuiltinFunc(_0(231), _1("dcbrt"), _2(1), _3(true), _4(false), _5(dcbrt), _6(701), _7(PG_CATALOG_NAMESPACE), _8(BOOTSTRAP_SUPERUSERID), _9(INTERNALlanguageId), _10(1), _11(0), _12(0), _13(0), _14(false), _15(false), _16(false), _17(false), _18('i'), _19(0), _20(1, 701), _21(NULL), _22(NULL), _23(NULL), _24(NULL), _25("dcbrt"), _26(NULL), _27(NULL), _28(NULL), _29(0), _30(false), _31(NULL), _32(false), _33(NULL), _34('f'))
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "access/xact.h"
|
||||
#include "catalog/dependency.h"
|
||||
#include "catalog/gs_matview.h"
|
||||
#include "catalog/gs_model.h"
|
||||
#include "catalog/heap.h"
|
||||
#include "catalog/index.h"
|
||||
#include "catalog/namespace.h"
|
||||
|
@ -1421,7 +1422,9 @@ static void doDeletion(const ObjectAddress* object, int flags)
|
|||
case OCLASS_SYNONYM:
|
||||
RemoveSynonymById(object->objectId);
|
||||
break;
|
||||
|
||||
case OCLASS_DB4AI_MODEL:
|
||||
remove_model_by_oid(object->objectId);
|
||||
break;
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE), errmsg("unrecognized object class: %u", object->classId)));
|
||||
|
@ -2371,6 +2374,8 @@ ObjectClass getObjectClass(const ObjectAddress* object)
|
|||
return OCLASS_GLOBAL_SETTING_ARGS;
|
||||
case ClientLogicColumnSettingsArgsId:
|
||||
return OCLASS_COLUMN_SETTING_ARGS;
|
||||
case ModelRelationId:
|
||||
return OCLASS_DB4AI_MODEL;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "access/sysattr.h"
|
||||
#include "catalog/catalog.h"
|
||||
#include "catalog/indexing.h"
|
||||
#include "catalog/gs_model.h"
|
||||
#include "catalog/objectaddress.h"
|
||||
#include "catalog/pg_authid.h"
|
||||
#include "catalog/pg_cast.h"
|
||||
|
@ -113,6 +114,7 @@ static THR_LOCAL const ObjectPropertyType ObjectProperty[] = {{CastRelationId, C
|
|||
CLAOID,
|
||||
Anum_pg_opclass_opcnamespace,
|
||||
},
|
||||
{ModelRelationId, GsModelOidIndexId, DB4AI_MODEL, InvalidAttrNumber},
|
||||
{OperatorRelationId, OperatorOidIndexId, OPEROID, Anum_pg_operator_oprnamespace},
|
||||
{OperatorFamilyRelationId, OpfamilyOidIndexId, OPFAMILYOID, Anum_pg_opfamily_opfnamespace},
|
||||
{AuthIdRelationId, AuthIdOidIndexId, AUTHOID, InvalidAttrNumber},
|
||||
|
@ -204,6 +206,7 @@ ObjectAddress get_object_address(
|
|||
address = get_object_address_relobject(objtype, objname, &relation, missing_ok);
|
||||
break;
|
||||
case OBJECT_DATABASE:
|
||||
case OBJECT_DB4AI_MODEL:
|
||||
case OBJECT_EXTENSION:
|
||||
case OBJECT_TABLESPACE:
|
||||
case OBJECT_ROLE:
|
||||
|
@ -394,6 +397,9 @@ static ObjectAddress get_object_address_unqualified(ObjectType objtype, List* qu
|
|||
case OBJECT_DATABASE:
|
||||
msg = gettext_noop("database name cannot be qualified");
|
||||
break;
|
||||
case OBJECT_DB4AI_MODEL:
|
||||
msg = gettext_noop("model name cannot be qualified");
|
||||
break;
|
||||
case OBJECT_EXTENSION:
|
||||
msg = gettext_noop("extension name cannot be qualified");
|
||||
break;
|
||||
|
@ -440,6 +446,11 @@ static ObjectAddress get_object_address_unqualified(ObjectType objtype, List* qu
|
|||
address.objectId = get_database_oid(name, missing_ok);
|
||||
address.objectSubId = 0;
|
||||
break;
|
||||
case OBJECT_DB4AI_MODEL:
|
||||
address.classId = ModelRelationId;
|
||||
address.objectId = get_model_oid(name, missing_ok);
|
||||
address.objectSubId = 0;
|
||||
break;
|
||||
case OBJECT_EXTENSION:
|
||||
address.classId = ExtensionRelationId;
|
||||
address.objectId = get_extension_oid(name, missing_ok);
|
||||
|
@ -816,6 +827,7 @@ void check_object_ownership(
|
|||
if (!pg_type_ownercheck(address.objectId, roleid))
|
||||
aclcheck_error_type(ACLCHECK_NO_PRIV, address.objectId);
|
||||
break;
|
||||
case OBJECT_DB4AI_MODEL:
|
||||
case OBJECT_DOMAIN:
|
||||
case OBJECT_ATTRIBUTE:
|
||||
if (!pg_type_ownercheck(address.objectId, roleid))
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "catalog/pg_authid.h"
|
||||
#include "catalog/pg_namespace.h"
|
||||
#include "catalog/pg_proc.h"
|
||||
#include "db4ai/predict_by.h"
|
||||
#include "access/transam.h"
|
||||
#include "utils/fmgroids.h"
|
||||
#include "../utils/pg_builtin_proc.h"
|
||||
|
|
|
@ -5776,6 +5776,20 @@ static DropSynonymStmt* _copyDropSynonymStmt(const DropSynonymStmt* from)
|
|||
return newnode;
|
||||
}
|
||||
|
||||
// DB4AI
|
||||
static CreateModelStmt* _copyCreateModelStmt(const CreateModelStmt* from){
|
||||
CreateModelStmt* newnode = makeNode(CreateModelStmt);
|
||||
COPY_STRING_FIELD(model);
|
||||
COPY_STRING_FIELD(architecture);
|
||||
COPY_NODE_FIELD(hyperparameters);
|
||||
COPY_NODE_FIELD(select_query);
|
||||
COPY_NODE_FIELD(model_features);
|
||||
COPY_NODE_FIELD(model_target);
|
||||
COPY_SCALAR_FIELD(algorithm);
|
||||
|
||||
return newnode;
|
||||
}
|
||||
|
||||
/* ****************************************************************
|
||||
* pg_list.h copy functions
|
||||
* ****************************************************************
|
||||
|
@ -7079,6 +7093,9 @@ void* copyObject(const void* from)
|
|||
case T_DropSynonymStmt:
|
||||
retval = _copyDropSynonymStmt((DropSynonymStmt*)from);
|
||||
break;
|
||||
case T_CreateModelStmt: // DB4AI
|
||||
retval = _copyCreateModelStmt((CreateModelStmt*) from);
|
||||
break;
|
||||
case T_A_Expr:
|
||||
retval = _copyAExpr((A_Expr*)from);
|
||||
break;
|
||||
|
|
|
@ -231,6 +231,9 @@ Oid exprType(const Node* expr)
|
|||
case T_Rownum:
|
||||
type = INT8OID;
|
||||
break;
|
||||
case T_GradientDescentExpr:
|
||||
type = ((const GradientDescentExpr*)expr)->fieldtype;
|
||||
break;
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE), errmsg("unrecognized node type: %d", (int)nodeTag(expr))));
|
||||
|
@ -840,6 +843,9 @@ Oid exprCollation(const Node* expr)
|
|||
case T_PlaceHolderVar:
|
||||
coll = exprCollation((Node*)((const PlaceHolderVar*)expr)->phexpr);
|
||||
break;
|
||||
case T_GradientDescentExpr:
|
||||
coll = InvalidOid;
|
||||
break;
|
||||
default:
|
||||
ereport(
|
||||
ERROR, (errcode(ERRCODE_DATATYPE_MISMATCH), errmsg("unrecognized node type: %d", (int)nodeTag(expr))));
|
||||
|
@ -1549,6 +1555,7 @@ bool expression_tree_walker(Node* node, bool (*walker)(), void* context)
|
|||
case T_Null:
|
||||
case T_PgFdwRemoteInfo:
|
||||
case T_Rownum:
|
||||
case T_GradientDescentExpr:
|
||||
/* primitive node types with no expression subnodes */
|
||||
break;
|
||||
case T_Aggref: {
|
||||
|
|
|
@ -570,7 +570,17 @@ static const TagStr g_tagStrArr[] = {{T_Invalid, "Invalid"},
|
|||
{T_SkewRelInfo, "SkewRelInfo"},
|
||||
{T_SkewColumnInfo, "SkewColumnInfo"},
|
||||
{T_SkewValueInfo, "SkewValueInfo"},
|
||||
{T_QualSkewInfo, "QualSkewInfo"}};
|
||||
{T_QualSkewInfo, "QualSkewInfo"},
|
||||
// DB4AI
|
||||
{T_CreateModelStmt, "CreateModelStmt"},
|
||||
{T_PredictByFunction, "PredictByFunction"},
|
||||
{T_GradientDescent, "GradientDescent"},
|
||||
{T_GradientDescentState, "GradientDescentState"},
|
||||
{T_KMeans, "Kmeans"},
|
||||
{T_KMeansState, "KmeansState"},
|
||||
{T_GradientDescentExpr, "GradientDescentExpr"},
|
||||
{T_GradientDescentExprState, "GradientDescentExprState"},
|
||||
};
|
||||
|
||||
char* nodeTagToString(NodeTag tag)
|
||||
{
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#include "pgxc/pgxc.h"
|
||||
#include "pgxc/pgFdwRemote.h"
|
||||
#endif
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
/*
|
||||
* Macros to simplify output of different kinds of fields. Use these
|
||||
|
@ -5129,6 +5130,33 @@ static void _outIndexVar(StringInfo str, IndexVar* node)
|
|||
WRITE_BOOL_FIELD(indexpath);
|
||||
}
|
||||
|
||||
static void _outGradientDescent(StringInfo str, GradientDescent* node)
|
||||
{
|
||||
WRITE_NODE_TYPE("SGD");
|
||||
_outPlanInfo(str, (Plan*)node);
|
||||
appendStringInfoString(str, " :algorithm ");
|
||||
appendStringInfoString(str, gd_get_algorithm(node->algorithm)->name);
|
||||
appendStringInfoString(str, " :optimizer ");
|
||||
appendStringInfoString(str, gd_get_optimizer_name(node->optimizer));
|
||||
WRITE_INT_FIELD(targetcol);
|
||||
WRITE_INT_FIELD(max_iterations);
|
||||
WRITE_INT_FIELD(max_seconds);
|
||||
WRITE_INT_FIELD(batch_size);
|
||||
WRITE_BOOL_FIELD(verbose);
|
||||
WRITE_FLOAT_FIELD(learning_rate, "%.16g");
|
||||
WRITE_FLOAT_FIELD(decay, "%.16g");
|
||||
WRITE_FLOAT_FIELD(tolerance, "%.16g");
|
||||
WRITE_INT_FIELD(seed);
|
||||
WRITE_FLOAT_FIELD(lambda, "%.16g");
|
||||
}
|
||||
|
||||
static void _outGradientDescentExpr(StringInfo str, GradientDescentExpr* node)
|
||||
{
|
||||
WRITE_NODE_TYPE("GradientDescentExpr");
|
||||
WRITE_UINT_FIELD(field);
|
||||
WRITE_OID_FIELD(fieldtype);
|
||||
}
|
||||
|
||||
/*
|
||||
* _outNode -
|
||||
* converts a Node into ascii string and append it to 'str'
|
||||
|
@ -5945,6 +5973,11 @@ static void _outNode(StringInfo str, const void* obj)
|
|||
break;
|
||||
case T_RewriteHint:
|
||||
_outRewriteHint(str, (RewriteHint *)obj);
|
||||
case T_GradientDescent:
|
||||
_outGradientDescent(str, (GradientDescent*)obj);
|
||||
break;
|
||||
case T_GradientDescentExpr:
|
||||
_outGradientDescentExpr(str, (GradientDescentExpr*)obj);
|
||||
break;
|
||||
default:
|
||||
|
||||
|
|
|
@ -80,6 +80,10 @@
|
|||
#include "instruments/instr_unique_sql.h"
|
||||
#include "streaming/init.h"
|
||||
|
||||
#include "db4ai/aifuncs.h"
|
||||
#include "db4ai/create_model.h"
|
||||
#include "db4ai/hyperparameter_validation.h"
|
||||
|
||||
#ifndef ENABLE_MULTIPLE_NODES
|
||||
#include "optimizer/clauses.h"
|
||||
#endif
|
||||
|
@ -274,6 +278,77 @@ Query* transformTopLevelStmt(ParseState* pstate, Node* parseTree, bool isFirstNo
|
|||
return transformStmt(pstate, parseTree, isFirstNode, isCreateView);
|
||||
}
|
||||
|
||||
|
||||
Query* transformCreateModelStmt(ParseState* pstate, CreateModelStmt* stmt)
|
||||
{
|
||||
SelectStmt* select_stmt = (SelectStmt*) stmt->select_query;
|
||||
|
||||
stmt->algorithm = get_algorithm_ml(stmt->architecture);
|
||||
if (stmt->algorithm == INVALID_ALGORITHM_ML) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Non recognized ML model architecture definition %s", stmt->architecture)));
|
||||
}
|
||||
if (SearchSysCacheExists1(DB4AI_MODEL, CStringGetDatum(stmt->model))) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("The model name \"%s\" already exists in gs_model_warehouse.", stmt->model)));
|
||||
}
|
||||
|
||||
// Create the projection for the AI operator in the query plan
|
||||
// If the algorithm is supervised, the target is always the first element of the list
|
||||
if (is_supervised(stmt->algorithm)) {
|
||||
if (list_length(stmt->model_features) == 0) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Supervised ML algorithms require FEATURES clause")));
|
||||
}else if (list_length(stmt->model_target) == 0) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Supervised ML algorithms require TARGET clause")));
|
||||
}else if (list_length(stmt->model_target) > 1) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Target clause only supports one expression")));
|
||||
}
|
||||
}else{
|
||||
if (list_length(stmt->model_target) > 0) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Unsupervised ML algorithms cannot have TARGET clause")));
|
||||
}
|
||||
}
|
||||
|
||||
select_stmt->targetList = NULL;
|
||||
foreach_cell(it, stmt->model_target) {
|
||||
select_stmt->targetList = lappend(select_stmt->targetList, lfirst(it));
|
||||
}
|
||||
|
||||
if (list_length(stmt->model_features) > 0) { // User given projection
|
||||
foreach_cell(it, stmt->model_features) {
|
||||
select_stmt->targetList = lappend(select_stmt->targetList, lfirst(it));
|
||||
}
|
||||
|
||||
} else { // No projection
|
||||
ResTarget *rt = makeNode(ResTarget);
|
||||
ColumnRef *cr = makeNode(ColumnRef);
|
||||
|
||||
cr->fields = list_make1(makeNode(A_Star));
|
||||
cr->location = -1;
|
||||
|
||||
rt->name = NULL;
|
||||
rt->indirection = NIL;
|
||||
rt->val = (Node *)cr;
|
||||
rt->location = -1;
|
||||
select_stmt->targetList = lappend(select_stmt->targetList, rt);
|
||||
}
|
||||
|
||||
// Transform the select query that we prepared for the training operator
|
||||
Query* select_query = transformStmt(pstate, (Node*) select_stmt);
|
||||
stmt->select_query = (Node*) select_query;
|
||||
|
||||
/* represent the command as a utility Query */
|
||||
Query* result = makeNode(Query);
|
||||
result->commandType = CMD_UTILITY;
|
||||
result->utilityStmt = (Node*)stmt;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* transformStmt -
|
||||
* recursively transform a Parse tree into a Query tree.
|
||||
|
@ -333,6 +408,11 @@ Query* transformStmt(ParseState* pstate, Node* parseTree, bool isFirstNode, bool
|
|||
result = transformCreateTableAsStmt(pstate, (CreateTableAsStmt*)parseTree);
|
||||
break;
|
||||
|
||||
case T_CreateModelStmt:
|
||||
result = transformCreateModelStmt(pstate, (CreateModelStmt*) parseTree);
|
||||
break;
|
||||
|
||||
|
||||
default:
|
||||
|
||||
/*
|
||||
|
|
|
@ -214,6 +214,7 @@ bool IsValidGroupname(const char *input);
|
|||
static bool checkNlssortArgs(const char *argname);
|
||||
static void ParseUpdateMultiSet(List *set_target_list, SelectStmt *stmt, core_yyscan_t yyscanner);
|
||||
static void parameter_check_execute_direct(const char* query);
|
||||
|
||||
%}
|
||||
|
||||
%pure-parser
|
||||
|
@ -316,6 +317,13 @@ static void parameter_check_execute_direct(const char* query);
|
|||
MergeStmt CreateMatViewStmt RefreshMatViewStmt
|
||||
CreateWeakPasswordDictionaryStmt DropWeakPasswordDictionaryStmt
|
||||
|
||||
/* <DB4AI> */
|
||||
|
||||
/* TRAIN MODEL */
|
||||
%type <node> CreateModelStmt hyperparameter_name_value DropModelStmt
|
||||
%type <list> features_clause hyperparameter_name_value_list target_clause with_hyperparameters_clause
|
||||
/* </DB4AI> */
|
||||
|
||||
%type <node> select_no_parens select_with_parens select_clause
|
||||
simple_select values_clause
|
||||
|
||||
|
@ -687,7 +695,7 @@ static void parameter_check_execute_direct(const char* query);
|
|||
* DOT_DOT is unused in the core SQL grammar, and so will always provoke
|
||||
* parse errors. It is needed by PL/pgsql.
|
||||
*/
|
||||
%token <str> IDENT FCONST SCONST BCONST XCONST Op CmpOp COMMENTSTRING
|
||||
%token <str> IDENT FCONST SCONST BCONST VCONST XCONST Op CmpOp COMMENTSTRING
|
||||
%token <ival> ICONST PARAM
|
||||
%token TYPECAST ORA_JOINOP DOT_DOT COLON_EQUALS PARA_EQUALS
|
||||
|
||||
|
@ -701,8 +709,8 @@ static void parameter_check_execute_direct(const char* query);
|
|||
/* ordinary key words in alphabetical order */
|
||||
/* PGXC - added DISTRIBUTE, DIRECT, COORDINATOR, CLEAN, NODE, BARRIER, SLICE, DATANODE */
|
||||
%token <keyword> ABORT_P ABSOLUTE_P ACCESS ACCOUNT ACTION ADD_P ADMIN AFTER
|
||||
AGGREGATE ALGORITHM ALL ALSO ALTER ALWAYS ANALYSE ANALYZE AND ANY APP ARRAY AS ASC
|
||||
ASSERTION ASSIGNMENT ASYMMETRIC AT ATTRIBUTE AUDIT AUTHID AUTHORIZATION AUTOEXTEND AUTOMAPPED
|
||||
AGGREGATE ALGORITHM ALL ALSO ALTER ALWAYS ANALYSE ANALYZE AND ANY APP ARCHIVE ARRAY AS ASC
|
||||
ASSERTION ASSIGNMENT ASYMMETRIC AT ATTRIBUTE AUDIT AUTHID AUTHORIZATION AUTOEXTEND AUTOMAPPED AUTOML
|
||||
|
||||
BACKWARD BARRIER BEFORE BEGIN_NON_ANOYBLOCK BEGIN_P BETWEEN BIGINT BINARY BINARY_DOUBLE BINARY_INTEGER BIT BLOB_P BOGUS
|
||||
BOOLEAN_P BOTH BUCKETS BY BYTEAWITHOUTORDER BYTEAWITHOUTORDERWITHEQUAL
|
||||
|
@ -728,6 +736,7 @@ static void parameter_check_execute_direct(const char* query);
|
|||
EXTENSION EXTERNAL EXTRACT
|
||||
|
||||
FALSE_P FAMILY FAST FENCED FETCH FILEHEADER_P FILL_MISSING_FIELDS FILTER FIRST_P FIXED_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORMATTER FORWARD
|
||||
FEATURES // DB4AI
|
||||
FREEZE FROM FULL FUNCTION FUNCTIONS
|
||||
|
||||
GENERATED GLOBAL GLOBAL_FUNCTION GRANT GRANTED GREATEST GROUP_P GROUPING_P
|
||||
|
@ -748,6 +757,7 @@ static void parameter_check_execute_direct(const char* query);
|
|||
LEAST LESS LEFT LEVEL LIKE LIMIT LIST LISTEN LOAD LOCAL LOCALTIME LOCALTIMESTAMP
|
||||
LOCATION LOCK_P LOG_P LOGGING LOGIN_ANY LOGIN_FAILURE LOGIN_SUCCESS LOGOUT LOOP
|
||||
MAPPING MASKING MASTER MATCH MATERIALIZED MATCHED MAXEXTENTS MAXSIZE MAXTRANS MAXVALUE MERGE MINUS_P MINUTE_P MINVALUE MINEXTENTS MODE MODIFY_P MONTH_P MOVE MOVEMENT
|
||||
MODEL // DB4AI
|
||||
NAME_P NAMES NATIONAL NATURAL NCHAR NEXT NLSSORT NO NOCOMPRESS NOCYCLE NODE NOLOGGING NOMAXVALUE NOMINVALUE NONE
|
||||
NOT NOTHING NOTIFY NOTNULL NOWAIT NULL_P NULLIF NULLS_P NUMBER_P NUMERIC NUMSTR NVARCHAR2 NVL
|
||||
|
||||
|
@ -756,24 +766,28 @@ static void parameter_check_execute_direct(const char* query);
|
|||
|
||||
PACKAGE PARSER PARTIAL PARTITION PARTITIONS PASSING PASSWORD PCTFREE PER_P PERCENT PERFORMANCE PERM PLACING PLAN PLANS POLICY POSITION
|
||||
/* PGXC_BEGIN */
|
||||
POOL PRECEDING PRECISION PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
|
||||
POOL PRECEDING PRECISION
|
||||
/* PGXC_END */
|
||||
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE
|
||||
PREDICT // DB4AI
|
||||
/* PGXC_BEGIN */
|
||||
PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
|
||||
/* PGXC_END */
|
||||
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE PUBLISH PURGE
|
||||
|
||||
QUERY QUOTE
|
||||
|
||||
RANDOMIZED RANGE RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
|
||||
RANDOMIZED RANGE RATIO RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
|
||||
RELATIVE_P RELEASE RELOPTIONS REMOTE_P REMOVE RENAME REPEATABLE REPLACE REPLICA
|
||||
RESET RESIZE RESOURCE RESTART RESTRICT RETURN RETURNING RETURNS REUSE REVOKE RIGHT ROLE ROLES ROLLBACK ROLLUP
|
||||
ROW ROWNUM ROWS RULE
|
||||
|
||||
SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
|
||||
SAMPLE SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
|
||||
SERIALIZABLE SERVER SESSION SESSION_USER SET SETS SETOF SHARE SHIPPABLE SHOW SHUTDOWN
|
||||
SIMILAR SIMPLE SIZE SLICE SMALLDATETIME SMALLDATETIME_FORMAT_P SMALLINT SNAPSHOT SOME SOURCE_P SPACE SPILL SPLIT STABLE STANDALONE_P START
|
||||
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STREAM STRICT_P STRIP_P SUBSTRING
|
||||
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STRATIFY STREAM STRICT_P STRIP_P SUBSTRING
|
||||
SYMMETRIC SYNONYM SYSDATE SYSID SYSTEM_P SYS_REFCURSOR
|
||||
|
||||
TABLE TABLES TABLESAMPLE TABLESPACE TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
|
||||
TABLE TABLES TABLESAMPLE TABLESPACE TARGET TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
|
||||
TO TRAILING TRANSACTION TREAT TRIGGER TRIM TRUE_P
|
||||
TRUNCATE TRUSTED TSFIELD TSTAG TSTIME TYPE_P TYPES_P
|
||||
|
||||
|
@ -812,6 +826,7 @@ static void parameter_check_execute_direct(const char* query);
|
|||
%nonassoc PARTIAL_EMPTY_PREC
|
||||
%nonassoc CLUSTER
|
||||
%nonassoc SET /* see relation_expr_opt_alias */
|
||||
%right FEATURES TARGET // DB4AI
|
||||
%left UNION EXCEPT MINUS_P
|
||||
%left INTERSECT
|
||||
%left OR
|
||||
|
@ -852,7 +867,7 @@ static void parameter_check_execute_direct(const char* query);
|
|||
*/
|
||||
%nonassoc UNBOUNDED /* ideally should have same precedence as IDENT */
|
||||
%nonassoc IDENT GENERATED NULL_P PARTITION RANGE ROWS PRECEDING FOLLOWING CUBE ROLLUP
|
||||
%left Op OPERATOR /* multi-character ops and user-defined operators */
|
||||
%left Op OPERATOR '@' /* multi-character ops and user-defined operators */
|
||||
%nonassoc NOTNULL
|
||||
%nonassoc ISNULL
|
||||
%nonassoc IS /* sets precedence for IS NULL, etc */
|
||||
|
@ -987,6 +1002,7 @@ stmt :
|
|||
| CreateFunctionStmt
|
||||
| CreateGroupStmt
|
||||
| CreateMatViewStmt
|
||||
| CreateModelStmt // DB4AI
|
||||
| CreateNodeGroupStmt
|
||||
| CreateNodeStmt
|
||||
| CreateOpClassStmt
|
||||
|
@ -1033,6 +1049,7 @@ stmt :
|
|||
| DropFdwStmt
|
||||
| DropForeignServerStmt
|
||||
| DropGroupStmt
|
||||
| DropModelStmt // DB4AI
|
||||
| DropNodeGroupStmt
|
||||
| DropNodeStmt
|
||||
| DropOpClassStmt
|
||||
|
@ -7585,6 +7602,133 @@ AlterUserMappingStmt: ALTER USER MAPPING FOR auth_ident SERVER name alter_generi
|
|||
}
|
||||
;
|
||||
|
||||
|
||||
/*****************************************************************************
|
||||
*
|
||||
* QUERY:
|
||||
* CREATE MODEL <model_name>
|
||||
*
|
||||
*
|
||||
*****************************************************************************/
|
||||
|
||||
|
||||
CreateModelStmt:
|
||||
CREATE MODEL ColId
|
||||
USING ColId
|
||||
features_clause
|
||||
target_clause
|
||||
from_clause
|
||||
with_hyperparameters_clause
|
||||
{
|
||||
CreateModelStmt *n = makeNode(CreateModelStmt);
|
||||
n->model = pstrdup($3);
|
||||
n->architecture = pstrdup($5);
|
||||
n->model_features = $6;
|
||||
n->model_target = $7;
|
||||
|
||||
// The clause will be constructed in tranform
|
||||
SelectStmt *s = makeNode(SelectStmt);
|
||||
s->fromClause = $8;
|
||||
|
||||
n->select_query = (Node*) s;
|
||||
n->hyperparameters = $9;
|
||||
|
||||
$$ = (Node*) n;
|
||||
}
|
||||
;
|
||||
|
||||
features_clause:
|
||||
FEATURES target_list{
|
||||
List* result = $2;
|
||||
|
||||
// Verify that target clause is not '*'
|
||||
foreach_cell(it, result){
|
||||
ResTarget* n = (ResTarget*) lfirst(it);
|
||||
ColumnRef* cr = (n->val != NULL && IsA(n->val, ColumnRef)) ? (ColumnRef*)(n->val) : NULL;
|
||||
List* l = (cr != NULL) ? cr->fields : NULL;
|
||||
Node* node = list_length(l) > 0 ? linitial_node(Node, l) : NULL;
|
||||
if (node != NULL && IsA(node, A_Star)){
|
||||
elog(ERROR, "FEATURES clause cannot be *");
|
||||
}
|
||||
}
|
||||
|
||||
$$ = result;
|
||||
}
|
||||
| {
|
||||
List* result = NULL;
|
||||
$$ = result;
|
||||
}
|
||||
|
||||
target_clause:
|
||||
TARGET target_list{
|
||||
List* result = $2;
|
||||
|
||||
// Verify that target clause is not '*'
|
||||
foreach_cell(it, result){
|
||||
ResTarget* n = (ResTarget*) lfirst(it);
|
||||
ColumnRef* cr = (n->val != NULL && IsA(n->val, ColumnRef)) ? (ColumnRef*)(n->val) : NULL;
|
||||
List* l = (cr != NULL) ? cr->fields : NULL;
|
||||
Node* node = list_length(l) > 0 ? linitial_node(Node, l) : NULL;
|
||||
if (node != NULL && IsA(node, A_Star)){
|
||||
elog(ERROR, "TARGET clause cannot be *");
|
||||
}
|
||||
}
|
||||
|
||||
$$ = result;
|
||||
}
|
||||
| {
|
||||
List* result = NULL;
|
||||
$$ = result;
|
||||
}
|
||||
|
||||
|
||||
with_hyperparameters_clause:
|
||||
WITH hyperparameter_name_value_list { $$ = $2; }
|
||||
| { $$ = NULL; }
|
||||
;
|
||||
|
||||
hyperparameter_name_value_list:
|
||||
hyperparameter_name_value { $$ = list_make1($1); }
|
||||
| hyperparameter_name_value_list ',' hyperparameter_name_value
|
||||
{
|
||||
$$ = lappend($1,$3);
|
||||
}
|
||||
;
|
||||
|
||||
hyperparameter_name_value:
|
||||
ColLabel '=' var_value
|
||||
{
|
||||
VariableSetStmt *n = makeNode(VariableSetStmt);
|
||||
n->kind = VAR_SET_VALUE;
|
||||
n->name = $1;
|
||||
n->args = list_make1($3);
|
||||
$$ = (Node*) n;
|
||||
}
|
||||
| ColLabel '=' DEFAULT
|
||||
{
|
||||
VariableSetStmt *n = makeNode(VariableSetStmt);
|
||||
n->kind = VAR_SET_DEFAULT;
|
||||
n->name = $1;
|
||||
n->args = NULL;
|
||||
$$ = (Node*) n;
|
||||
}
|
||||
;
|
||||
|
||||
DropModelStmt:
|
||||
DROP MODEL ColId opt_drop_behavior
|
||||
{
|
||||
DropStmt *n = makeNode(DropStmt);
|
||||
n->removeType = OBJECT_DB4AI_MODEL;
|
||||
n->objects = list_make1(list_make1(makeString($3)));
|
||||
n->arguments = NULL;
|
||||
n->behavior = $4;
|
||||
n->missing_ok = false;
|
||||
n->concurrent = false;
|
||||
n->isProcedure = false;
|
||||
$$ = (Node *)n;
|
||||
}
|
||||
;
|
||||
|
||||
/*****************************************************************************
|
||||
*
|
||||
* QUERIES For ROW LEVEL SECURITY:
|
||||
|
@ -14823,6 +14967,7 @@ ExplainableStmt:
|
|||
| MergeStmt
|
||||
| DeclareCursorStmt
|
||||
| CreateAsStmt
|
||||
| CreateModelStmt
|
||||
| ExecuteStmt /* by default all are $$=$1 */
|
||||
;
|
||||
|
||||
|
@ -17840,6 +17985,15 @@ a_expr: c_expr { $$ = $1; }
|
|||
list_make1($1), @2),
|
||||
@2);
|
||||
}
|
||||
| PREDICT BY ColId '(' FEATURES func_arg_list ')'
|
||||
{
|
||||
PredictByFunction * n = makeNode(PredictByFunction);
|
||||
n->model_name = $3;
|
||||
n->model_name_location = @3;
|
||||
n->model_args = $6;
|
||||
n->model_args_location = @6;
|
||||
$$ = (Node*) n;
|
||||
}
|
||||
;
|
||||
|
||||
/*
|
||||
|
@ -20191,6 +20345,7 @@ unreserved_keyword:
|
|||
| EXTERNAL
|
||||
| FAMILY
|
||||
| FAST
|
||||
| FEATURES // DB4AI
|
||||
| FILEHEADER_P
|
||||
| FILL_MISSING_FIELDS
|
||||
| FILTER
|
||||
|
@ -20276,6 +20431,7 @@ unreserved_keyword:
|
|||
| MINUTE_P
|
||||
| MINVALUE
|
||||
| MODE
|
||||
| MODEL // DB4AI
|
||||
| MONTH_P
|
||||
| MOVE
|
||||
| MOVEMENT
|
||||
|
@ -20321,6 +20477,7 @@ unreserved_keyword:
|
|||
| POLICY
|
||||
| POOL
|
||||
| PRECEDING
|
||||
| PREDICT // DB4AI
|
||||
/* PGXC_BEGIN */
|
||||
| PREFERRED
|
||||
/* PGXC_END */
|
||||
|
@ -20338,6 +20495,7 @@ unreserved_keyword:
|
|||
| QUOTE
|
||||
| RANDOMIZED
|
||||
| RANGE
|
||||
| RATIO
|
||||
| RAW '(' Iconst ')' { $$ = "raw";}
|
||||
| RAW %prec UNION { $$ = "raw";}
|
||||
| READ
|
||||
|
@ -20373,6 +20531,7 @@ unreserved_keyword:
|
|||
| ROLLUP
|
||||
| ROWS
|
||||
| RULE
|
||||
| SAMPLE
|
||||
| SAVEPOINT
|
||||
| SCHEMA
|
||||
| SCROLL
|
||||
|
@ -20410,6 +20569,7 @@ unreserved_keyword:
|
|||
| STORAGE
|
||||
| STORE_P
|
||||
| STORED
|
||||
| STRATIFY
|
||||
| STREAM
|
||||
| STRICT_P
|
||||
| STRIP_P
|
||||
|
@ -20419,6 +20579,7 @@ unreserved_keyword:
|
|||
| SYSTEM_P
|
||||
| TABLES
|
||||
| TABLESPACE
|
||||
| TARGET
|
||||
| TEMP
|
||||
| TEMPLATE
|
||||
| TEMPORARY
|
||||
|
@ -22198,6 +22359,7 @@ static void parameter_check_execute_direct(const char* query)
|
|||
errmsg("must be system admin or monitor admin to use EXECUTE DIRECT")));
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Must undefine this stuff before including scan.c, since it has different
|
||||
* definitions for these macros.
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "catalog/pg_proc.h"
|
||||
#include "commands/dbcommands.h"
|
||||
#include "commands/sequence.h"
|
||||
#include "db4ai/predict_by.h"
|
||||
#include "foreign/foreign.h"
|
||||
#include "miscadmin.h"
|
||||
#include "nodes/makefuncs.h"
|
||||
|
@ -67,6 +68,7 @@ static Node* transformXmlExpr(ParseState* pstate, XmlExpr* x);
|
|||
static Node* transformXmlSerialize(ParseState* pstate, XmlSerialize* xs);
|
||||
static Node* transformBooleanTest(ParseState* pstate, BooleanTest* b);
|
||||
static Node* transformCurrentOfExpr(ParseState* pstate, CurrentOfExpr* cexpr);
|
||||
static Node* transformPredictByFunction(ParseState* pstate, PredictByFunction* cexpr);
|
||||
static Node* transformColumnRef(ParseState* pstate, ColumnRef* cref);
|
||||
static Node* transformWholeRowRef(ParseState* pstate, RangeTblEntry* rte, int location);
|
||||
static Node* transformIndirection(ParseState* pstate, Node* basenode, List* indirection);
|
||||
|
@ -286,6 +288,10 @@ Node* transformExpr(ParseState* pstate, Node* expr)
|
|||
result = transformCurrentOfExpr(pstate, (CurrentOfExpr*)expr);
|
||||
break;
|
||||
|
||||
case T_PredictByFunction:
|
||||
result = transformPredictByFunction(pstate, (PredictByFunction*) expr);
|
||||
break;
|
||||
|
||||
/*********************************************
|
||||
* Quietly accept node types that may be presented when we are
|
||||
* called on an already-transformed tree.
|
||||
|
@ -2027,6 +2033,102 @@ static Node* transformCurrentOfExpr(ParseState* pstate, CurrentOfExpr* cexpr)
|
|||
return (Node*)cexpr;
|
||||
}
|
||||
|
||||
// Locate in the system catalog the information for a model name
|
||||
static char* select_prediction_function(Model* model){
|
||||
|
||||
char* result;
|
||||
switch(model->return_type){
|
||||
case BOOLOID:
|
||||
result = "db4ai_predict_by_bool";
|
||||
break;
|
||||
case FLOAT4OID:
|
||||
result = "db4ai_predict_by_float4";
|
||||
break;
|
||||
case FLOAT8OID:
|
||||
result = "db4ai_predict_by_float8";
|
||||
break;
|
||||
case INT1OID:
|
||||
case INT2OID:
|
||||
case INT4OID:
|
||||
result = "db4ai_predict_by_int32";
|
||||
break;
|
||||
case INT8OID:
|
||||
result = "db4ai_predict_by_int64";
|
||||
break;
|
||||
case NUMERICOID:
|
||||
result = "db4ai_predict_by_numeric";
|
||||
break;
|
||||
case VARCHAROID:
|
||||
case BPCHAROID:
|
||||
case CHAROID:
|
||||
case TEXTOID:
|
||||
result = "db4ai_predict_by_text";
|
||||
break;
|
||||
|
||||
default:
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Cannot trigger prediction for model with oid %d", model->return_type)));
|
||||
result = NULL;
|
||||
break;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Convert the PredictByFunction created during parsing phase into the function
|
||||
// call that computes the prediction of the model. We cannot do this during parsing
|
||||
// because at that moment we do not know the return type of the model, which is obtained
|
||||
// from the catalog
|
||||
static Node* transformPredictByFunction(ParseState* pstate, PredictByFunction* p)
|
||||
{
|
||||
FuncCall* n = makeNode(FuncCall);
|
||||
|
||||
if (p->model_name == NULL) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Model name for prediction cannot be null")));
|
||||
}
|
||||
|
||||
Model* model = get_model(p->model_name, true);
|
||||
if (model == NULL) {
|
||||
ereport(ERROR, (errmsg(
|
||||
"No model found with name %s", p->model_name)));
|
||||
}
|
||||
|
||||
// Locate the proper function according to the model name
|
||||
char* function_name = select_prediction_function(model);
|
||||
ereport(DEBUG1, (errmsg(
|
||||
"Selecting prediction function %s for model %s",
|
||||
function_name, p->model_name)));
|
||||
n->funcname = list_make1(makeString(function_name));
|
||||
n->colname = p->model_name;
|
||||
|
||||
// Fill model name parameter
|
||||
A_Const* model_name_aconst = makeNode(A_Const);
|
||||
model_name_aconst->val.type = T_String;
|
||||
model_name_aconst->val.val.str = p->model_name;
|
||||
model_name_aconst->location = p->model_name_location;
|
||||
|
||||
// Copy other parameters
|
||||
n->args = list_make1(model_name_aconst);
|
||||
if (list_length(p->model_args) > 0) {
|
||||
n->args = lappend3(n->args, p->model_args);
|
||||
}else{
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Innput features for the model not specified")));
|
||||
}
|
||||
|
||||
n->agg_order = NULL;
|
||||
n->agg_star = FALSE;
|
||||
n->agg_distinct = FALSE;
|
||||
n->func_variadic = FALSE;
|
||||
n->over = NULL;
|
||||
n->location = p->model_args_location;
|
||||
n->call_func = false;
|
||||
|
||||
return transformExpr(pstate, (Node*)n);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Construct a whole-row reference to represent the notation "relation.*".
|
||||
*/
|
||||
|
|
|
@ -1495,6 +1495,14 @@ static int FigureColnameInternal(Node* node, char** name)
|
|||
return 2;
|
||||
}
|
||||
break;
|
||||
case T_PredictByFunction: {
|
||||
size_t len = strlen(((PredictByFunction*)node)->model_name) + strlen("_pred") + 1;
|
||||
char* colname = (char*)palloc0(len);
|
||||
errno_t rc = snprintf_s(colname, len, len - 1, "%s_pred", ((PredictByFunction*)node)->model_name);
|
||||
securec_check_ss(rc, "\0", "\0");
|
||||
*name = colname;
|
||||
return 1;
|
||||
} break;
|
||||
case T_TypeCast:
|
||||
strength = FigureColnameInternal(((TypeCast*)node)->arg, name);
|
||||
if (strength <= 1) {
|
||||
|
|
|
@ -92,6 +92,7 @@
|
|||
#include "utils/typcache.h"
|
||||
#include "utils/xml.h"
|
||||
#include "vecexecutor/vecnodes.h"
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
/* ----------
|
||||
* Pretty formatting constants
|
||||
|
@ -9234,6 +9235,11 @@ static void get_rule_expr(Node* node, deparse_context* context, bool showimplici
|
|||
}
|
||||
} break;
|
||||
|
||||
case T_GradientDescentExpr: {
|
||||
GradientDescentExpr* gdnode = (GradientDescentExpr*)node;
|
||||
appendStringInfo(buf, "GD(%s)", gd_get_expr_name(gdnode->field));
|
||||
} break;
|
||||
|
||||
default:
|
||||
if (context->qrw_phase)
|
||||
appendStringInfo(buf, "<unknown %d>", (int)nodeTag(node));
|
||||
|
|
|
@ -138,6 +138,7 @@
|
|||
#include "catalog/pg_streaming_stream.h"
|
||||
#include "catalog/pg_streaming_cont_query.h"
|
||||
#include "catalog/pg_streaming_reaper_status.h"
|
||||
#include "catalog/gs_model.h"
|
||||
#include "commands/matview.h"
|
||||
#include "commands/sec_rls_cmds.h"
|
||||
#include "commands/tablespace.h"
|
||||
|
@ -292,6 +293,7 @@ static const FormData_pg_attribute Desc_gs_client_global_keys[Natts_gs_client_gl
|
|||
static const FormData_pg_attribute Desc_gs_client_global_keys_args[Natts_gs_client_global_keys_args] = {Schema_gs_client_global_keys_args};
|
||||
|
||||
static const FormData_pg_attribute Desc_gs_opt_model[Natts_gs_opt_model] = {Schema_gs_opt_model};
|
||||
static const FormData_pg_attribute Desc_gs_model_warehouse[Natts_gs_model_warehouse] = {Schema_gs_model_warehouse};
|
||||
|
||||
/* Please add to the array in ascending order of oid value */
|
||||
static struct CatalogRelationBuildParam catalogBuildParam[CATALOG_NUM] = {{DefaultAclRelationId,
|
||||
|
@ -768,6 +770,15 @@ static struct CatalogRelationBuildParam catalogBuildParam[CATALOG_NUM] = {{Defau
|
|||
Desc_pg_ts_template,
|
||||
false,
|
||||
true},
|
||||
{ModelRelationId,
|
||||
"gs_model_warehouse",
|
||||
ModelRelation_Rowtype_Id,
|
||||
false,
|
||||
true,
|
||||
Natts_gs_model_warehouse,
|
||||
Desc_gs_model_warehouse,
|
||||
false,
|
||||
true},
|
||||
{DataSourceRelationId,
|
||||
"pg_extension_data_source",
|
||||
DataSourceRelation_Rowtype_Id,
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "access/sysattr.h"
|
||||
#include "catalog/gs_obsscaninfo.h"
|
||||
#include "catalog/gs_opt_model.h"
|
||||
#include "catalog/gs_model.h"
|
||||
#include "catalog/gs_policy_label.h"
|
||||
#include "catalog/indexing.h"
|
||||
#include "catalog/pg_aggregate.h"
|
||||
|
@ -299,6 +300,16 @@ static const struct cachedesc cacheinfo[] = {{AggregateRelationId, /* AGGFNOID *
|
|||
1,
|
||||
{ObjectIdAttributeNumber, 0, 0, 0},
|
||||
32},
|
||||
{ModelRelationId, /* DB4AI_MODELOID */
|
||||
GsModelOidIndexId,
|
||||
1,
|
||||
{ObjectIdAttributeNumber, 0, 0, 0},
|
||||
256},
|
||||
{ModelRelationId, /* DB4AI_MODEL */
|
||||
GsModelNameIndexId,
|
||||
1,
|
||||
{Anum_gs_model_model_name, 0, 0, 0},
|
||||
256},
|
||||
{DefaultAclRelationId, /* DEFACLROLENSPOBJ */
|
||||
DefaultAclRoleNspObjIndexId,
|
||||
3,
|
||||
|
|
|
@ -93,6 +93,7 @@ const module_data module_map[] = {{MOD_ALL, "ALL"},
|
|||
{MOD_THREAD_POOL, "THREAD_POOL"},
|
||||
{MOD_OPT_AI, "OPT_AI"},
|
||||
{MOD_GEN_COL, "GEN_COL"},
|
||||
{MOD_DB4AI, "DB4AI"},
|
||||
|
||||
/* add your module name above */
|
||||
{MOD_MAX, "BACKEND"}};
|
||||
|
|
|
@ -519,6 +519,7 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
|
|||
EXTENSION EXTERNAL EXTRACT
|
||||
|
||||
FALSE_P FAMILY FAST FENCED FETCH FILEHEADER_P FILL_MISSING_FIELDS FILTER FIRST_P FIXED_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORMATTER FORWARD
|
||||
FEATURES // DB4AI
|
||||
FREEZE FROM FULL FUNCTION FUNCTIONS
|
||||
|
||||
GENERATED GLOBAL GLOBAL_FUNCTION GRANT GRANTED GREATEST GROUP_P GROUPING_P
|
||||
|
@ -538,6 +539,7 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
|
|||
LEAST LESS LEFT LEVEL LIST LIKE LIMIT LISTEN LOAD LOCAL LOCALTIME LOCALTIMESTAMP
|
||||
LOCATION LOCK_P LOG_P LOGGING LOGIN_ANY LOGIN_SUCCESS LOGIN_FAILURE LOGOUT LOOP
|
||||
MAPPING MASKING MASTER MASTR MATCH MATERIALIZED MATCHED MAXEXTENTS MAXSIZE MAXTRANS MAXVALUE MERGE MINUS_P MINUTE_P MINVALUE MINEXTENTS MODE MODIFY_P MONTH_P MOVE MOVEMENT
|
||||
MODEL // DB4AI
|
||||
NAME_P NAMES NATIONAL NATURAL NCHAR NEXT NLSSORT NO NOCOMPRESS NOCYCLE NODE NOLOGGING NOMAXVALUE NOMINVALUE NONE
|
||||
NOT NOTHING NOTIFY NOTNULL NOWAIT NULL_P NULLIF NULLS_P NUMBER_P NUMERIC NUMSTR NVARCHAR2 NVL
|
||||
|
||||
|
@ -546,24 +548,28 @@ extern THR_LOCAL bool stmt_contains_operator_plus;
|
|||
|
||||
PACKAGE PARSER PARTIAL PARTITION PARTITIONS PASSING PASSWORD PCTFREE PER_P PERCENT PERFORMANCE PERM PLACING PLAN PLANS POLICY POSITION
|
||||
/* PGXC_BEGIN */
|
||||
POOL PRECEDING PRECISION PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
|
||||
POOL PRECEDING PRECISION
|
||||
/* PGXC_END */
|
||||
PREDICT
|
||||
/* PGXC_BEGIN */
|
||||
PREFERRED PREFIX PRESERVE PREPARE PREPARED PRIMARY
|
||||
/* PGXC_END */
|
||||
PRIVATE PRIOR PRIVILEGES PRIVILEGE PROCEDURAL PROCEDURE PROFILE
|
||||
|
||||
QUERY QUOTE
|
||||
|
||||
RANDOMIZED RANGE RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
|
||||
RANDOMIZED RANGE RATIO RAW READ REAL REASSIGN REBUILD RECHECK RECURSIVE REDISANYVALUE REF REFERENCES REFRESH REINDEX REJECT_P
|
||||
RELATIVE_P RELEASE RELOPTIONS REMOTE_P REMOVE RENAME REPEATABLE REPLACE REPLICA
|
||||
RESET RESIZE RESOURCE RESTART RESTRICT RETURN RETURNING RETURNS REUSE REVOKE RIGHT ROLE ROLES ROLLBACK ROLLUP
|
||||
ROW ROWNUM ROWS RULE
|
||||
|
||||
SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
|
||||
SAMPLE SAVEPOINT SCHEMA SCROLL SEARCH SECOND_P SECURITY SELECT SEQUENCE SEQUENCES
|
||||
SERIALIZABLE SERVER SESSION SESSION_USER SET SETS SETOF SHARE SHIPPABLE SHOW SHUTDOWN
|
||||
SIMILAR SIMPLE SIZE SLICE SMALLDATETIME SMALLDATETIME_FORMAT_P SMALLINT SNAPSHOT SOME SOURCE_P SPACE SPILL SPLIT STABLE STANDALONE_P START
|
||||
STATEMENT STATEMENT_ID STATISTICS STDIN STDOUT STORAGE STORE_P STORED STREAM STRICT_P STRIP_P SUBSTRING
|
||||
SYMMETRIC SYNONYM SYSDATE SYSID SYSTEM_P SYS_REFCURSOR
|
||||
|
||||
TABLE TABLES TABLESAMPLE TABLESPACE TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
|
||||
TABLE TABLES TABLESAMPLE TABLESPACE TARGET TEMP TEMPLATE TEMPORARY TEXT_P THAN THEN TIME TIME_FORMAT_P TIMESTAMP TIMESTAMP_FORMAT_P TIMESTAMPDIFF TINYINT
|
||||
TO TRAILING TRANSACTION TREAT TRIGGER TRIM TRUE_P
|
||||
TRUNCATE TRUSTED TSFIELD TSTAG TSTIME TYPE_P TYPES_P
|
||||
|
||||
|
@ -10714,6 +10720,7 @@ unreserved_keyword:
|
|||
| EXTERNAL
|
||||
| FAMILY
|
||||
| FAST
|
||||
| FEATURES // DB4AI
|
||||
| FILEHEADER_P
|
||||
| FILL_MISSING_FIELDS
|
||||
| FIRST_P
|
||||
|
@ -10796,6 +10803,7 @@ unreserved_keyword:
|
|||
| MINUTE_P
|
||||
| MINVALUE
|
||||
| MODE
|
||||
| MODEL // DB4AI
|
||||
| MONTH_P
|
||||
| MOVE
|
||||
| MOVEMENT
|
||||
|
@ -10838,6 +10846,7 @@ unreserved_keyword:
|
|||
| PLANS
|
||||
| POOL
|
||||
| PRECEDING
|
||||
| PREDICT // DB4AI
|
||||
/* PGXC_BEGIN */
|
||||
| PREFERRED
|
||||
/* PGXC_END */
|
||||
|
@ -10931,6 +10940,7 @@ unreserved_keyword:
|
|||
| SYSTEM_P
|
||||
| TABLES
|
||||
| TABLESPACE
|
||||
| TARGET // DB4AI
|
||||
| TEMP
|
||||
| TEMPLATE
|
||||
| TEMPORARY
|
||||
|
|
|
@ -301,7 +301,7 @@ static bool last_pragma;
|
|||
* Some of these are not directly referenced in this file, but they must be
|
||||
* here anyway.
|
||||
*/
|
||||
%token <str> IDENT FCONST SCONST BCONST XCONST Op CmpOp COMMENTSTRING
|
||||
%token <str> IDENT FCONST SCONST BCONST VCONST XCONST Op CmpOp COMMENTSTRING
|
||||
%token <ival> ICONST PARAM
|
||||
%token TYPECAST ORA_JOINOP DOT_DOT COLON_EQUALS PARA_EQUALS
|
||||
|
||||
|
|
|
@ -9,6 +9,6 @@ subdir = src/gausskernel/dbmind
|
|||
top_builddir = ../../..
|
||||
include $(top_builddir)/src/Makefile.global
|
||||
|
||||
SUBDIRS = kernel
|
||||
SUBDIRS = kernel db4ai
|
||||
|
||||
include $(top_srcdir)/src/gausskernel/common.mk
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
#This is the main CMAKE for build bin.
|
||||
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON)
|
||||
set(CMAKE_RULE_MESSAGES OFF)
|
||||
set(CMAKE_SKIP_RPATH TRUE)
|
||||
|
||||
set(CMAKE_MODULE_PATH
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/catalog
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/commands
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/executor
|
||||
)
|
||||
add_subdirectory(catalog)
|
||||
add_subdirectory(commands)
|
||||
add_subdirectory(executor)
|
||||
|
||||
if("${ENABLE_MULTIPLE_NODES}" STREQUAL "OFF")
|
||||
install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/snapshots
|
||||
DESTINATION share/postgresql/db4ai/ OPTIONAL
|
||||
)
|
||||
endif()
|
|
@ -0,0 +1,24 @@
|
|||
#---------------------------------------------------------------------------------------
|
||||
#
|
||||
# IDENTIFICATION
|
||||
# src/gausskernel/dbmind/db4ai/Makefile
|
||||
#
|
||||
# ---------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
subdir = src/gausskernel/dbmind/db4ai
|
||||
top_builddir = ../../../..
|
||||
|
||||
include $(top_builddir)/src/Makefile.global
|
||||
|
||||
ifneq "$(MAKECMDGOALS)" "clean"
|
||||
ifneq "$(MAKECMDGOALS)" "distclean"
|
||||
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
|
||||
-include $(DEPEND)
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
SUBDIRS = catalog commands executor
|
||||
|
||||
include $(top_srcdir)/src/gausskernel/common.mk
|
|
@ -0,0 +1,24 @@
|
|||
#This is the main CMAKE for build all components.
|
||||
|
||||
AUX_SOURCE_DIRECTORY(${CMAKE_CURRENT_SOURCE_DIR} TGT_catalog_SRC)
|
||||
|
||||
set(TGT_catalog_INC
|
||||
${PROJECT_OPENGS_DIR}/contrib/log_fdw
|
||||
${PROJECT_TRUNK_DIR}/distribute/bin/gds
|
||||
${PROJECT_SRC_DIR}/include/libcomm
|
||||
${PROJECT_SRC_DIR}/include
|
||||
${PROJECT_SRC_DIR}/lib/gstrace
|
||||
${LZ4_INCLUDE_PATH}
|
||||
${LIBCGROUP_INCLUDE_PATH}
|
||||
${LIBORC_INCLUDE_PATH}
|
||||
${EVENT_INCLUDE_PATH}
|
||||
${PROTOBUF_INCLUDE_PATH}
|
||||
${ZLIB_INCLUDE_PATH}
|
||||
)
|
||||
|
||||
set(catalog_DEF_OPTIONS ${MACRO_OPTIONS})
|
||||
set(catalog_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
|
||||
list(REMOVE_ITEM catalog_COMPILE_OPTIONS -fPIC)
|
||||
set(catalog_COMPILE_OPTIONS ${catalog_COMPILE_OPTIONS} -std=c++14 -fPIE)
|
||||
set(catalog_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
|
||||
add_static_objtarget(gausskernel_db4ai_catalog TGT_catalog_SRC TGT_catalog_INC "${catalog_DEF_OPTIONS}" "${catalog_COMPILE_OPTIONS}" "${catalog_LINK_OPTIONS}")
|
|
@ -0,0 +1,22 @@
|
|||
#---------------------------------------------------------------------------------------
|
||||
#
|
||||
# IDENTIFICATION
|
||||
# src/gausskernel/dbmind/executor/Makefile
|
||||
#
|
||||
# ---------------------------------------------------------------------------------------
|
||||
|
||||
subdir = src/gausskernel/dbmind/db4ai/catalog
|
||||
top_builddir = ../../../../..
|
||||
include $(top_builddir)/src/Makefile.global
|
||||
|
||||
ifneq "$(MAKECMDGOALS)" "clean"
|
||||
ifneq "$(MAKECMDGOALS)" "distclean"
|
||||
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
|
||||
-include $(DEPEND)
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
|
||||
OBJS = model_warehouse.o
|
||||
|
||||
include $(top_srcdir)/src/gausskernel/common.mk
|
|
@ -0,0 +1,760 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* command.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/catalog/model_warehouse.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "db4ai/gd.h"
|
||||
#include "db4ai/aifuncs.h"
|
||||
#include "access/tableam.h"
|
||||
#include "catalog/gs_model.h"
|
||||
#include "catalog/indexing.h"
|
||||
#include "catalog/pg_proc.h"
|
||||
#include "instruments/generate_report.h"
|
||||
#include "lib/stringinfo.h"
|
||||
#include "utils/fmgroids.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "utils/lsyscache.h"
|
||||
|
||||
typedef enum ListType {
|
||||
HYPERPARAMETERS = 0,
|
||||
COEFS,
|
||||
SCORES,
|
||||
} ListType;
|
||||
|
||||
|
||||
template <ListType ltype> void ListToTuple(List *list, Datum *name, Datum *value, Datum *oid);
|
||||
|
||||
template <ListType listType> void TupleToList(Model *model, Datum *names, Datum *values, Datum *oids);
|
||||
|
||||
template <ListType ltype> static void add_model_parameter(Model *model, const char *name, Oid type, Datum value);
|
||||
|
||||
static Datum string_to_datum(const char *str, Oid datatype);
|
||||
|
||||
void store_SGD(Datum *values, bool *nulls, ModelGradientDescent *SGDmodel);
|
||||
void get_SGD(HeapTuple *tuple, ModelGradientDescent *resGD, Form_gs_model_warehouse tuplePointer);
|
||||
|
||||
void store_kmeans(Datum *values, bool *nulls, ModelKMeans *kmeansModel);
|
||||
void get_kmeans(HeapTuple *tuple, ModelKMeans *modelKmeans);
|
||||
void splitStringFillCentroid(WHCentroid *curseCent, char *strDescribe);
|
||||
char *splitStringFillCoordinates(WHCentroid *curseCent, char *strCoordinates, int dimension);
|
||||
|
||||
// Store the model in the catalog tables
|
||||
void store_model(const Model *model)
|
||||
{
|
||||
HeapTuple tuple;
|
||||
int rc;
|
||||
Relation rel = NULL;
|
||||
Oid extOwner = GetUserId();
|
||||
Datum values[Natts_gs_model_warehouse];
|
||||
bool nulls[Natts_gs_model_warehouse];
|
||||
Datum ListNames, ListValues, ListOids;
|
||||
|
||||
if (SearchSysCacheExists1(DB4AI_MODEL, CStringGetDatum(model->model_name))) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("The model name \"%s\" already exists in gs_model_warehouse.", model->model_name)));
|
||||
}
|
||||
|
||||
rel = heap_open(ModelRelationId, RowExclusiveLock);
|
||||
|
||||
rc = memset_s(values, sizeof(values), 0, sizeof(values));
|
||||
securec_check(rc, "\0", "\0");
|
||||
rc = memset_s(nulls, sizeof(nulls), 0, sizeof(nulls));
|
||||
securec_check(rc, "\0", "\0");
|
||||
|
||||
values[Anum_gs_model_model_name - 1] = DirectFunctionCall1(namein, CStringGetDatum(model->model_name));
|
||||
values[Anum_gs_model_owner_oid - 1] = ObjectIdGetDatum(extOwner);
|
||||
values[Anum_gs_model_create_time - 1] = DirectFunctionCall1(timestamptz_timestamp, GetCurrentTimestamp());
|
||||
values[Anum_gs_model_processedTuples - 1] = Int64GetDatum(model->processed_tuples);
|
||||
values[Anum_gs_model_discardedTuples - 1] = Int64GetDatum(model->discarded_tuples);
|
||||
values[Anum_gs_model_process_time_secs - 1] = Float4GetDatum(model->pre_time_secs);
|
||||
values[Anum_gs_model_exec_time_secs - 1] = Float4GetDatum(model->exec_time_secs);
|
||||
values[Anum_gs_model_iterations - 1] = Int64GetDatum(model->num_actual_iterations);
|
||||
values[Anum_gs_model_outputType - 1] = ObjectIdGetDatum(model->return_type);
|
||||
values[Anum_gs_model_query - 1] = CStringGetTextDatum(model->sql);
|
||||
|
||||
if (model->hyperparameters == nullptr) {
|
||||
nulls[Anum_gs_model_hyperparametersNames - 1] = true;
|
||||
nulls[Anum_gs_model_hyperparametersValues - 1] = true;
|
||||
nulls[Anum_gs_model_hyperparametersOids - 1] = true;
|
||||
} else {
|
||||
ListToTuple<ListType::HYPERPARAMETERS>(model->hyperparameters, &ListNames, &ListValues, &ListOids);
|
||||
values[Anum_gs_model_hyperparametersNames - 1] = ListNames;
|
||||
values[Anum_gs_model_hyperparametersValues - 1] = ListValues;
|
||||
values[Anum_gs_model_hyperparametersOids - 1] = ListOids;
|
||||
}
|
||||
|
||||
if (model->train_info == nullptr) {
|
||||
nulls[Anum_gs_model_coefNames - 1] = true;
|
||||
nulls[Anum_gs_model_coefValues - 1] = true;
|
||||
nulls[Anum_gs_model_coefOids - 1] = true;
|
||||
} else {
|
||||
ListToTuple<ListType::COEFS>(model->train_info, &ListNames, &ListValues, &ListOids);
|
||||
values[Anum_gs_model_coefNames - 1] = ListNames;
|
||||
values[Anum_gs_model_coefValues - 1] = ListValues;
|
||||
values[Anum_gs_model_coefOids - 1] = ListOids;
|
||||
}
|
||||
|
||||
if (model->scores == nullptr) {
|
||||
nulls[Anum_gs_model_trainingScoresName - 1] = true;
|
||||
nulls[Anum_gs_model_trainingScoresValue - 1] = true;
|
||||
} else {
|
||||
ListToTuple<ListType::SCORES>(model->scores, &ListNames, &ListValues, &ListOids);
|
||||
values[Anum_gs_model_trainingScoresName - 1] = ListNames;
|
||||
values[Anum_gs_model_trainingScoresValue - 1] = ListValues;
|
||||
}
|
||||
|
||||
switch (model->algorithm) {
|
||||
case LOGISTIC_REGRESSION:
|
||||
case SVM_CLASSIFICATION:
|
||||
case LINEAR_REGRESSION: {
|
||||
store_SGD(values, nulls, (ModelGradientDescent *)model);
|
||||
} break;
|
||||
case KMEANS: {
|
||||
store_kmeans(values, nulls, (ModelKMeans *)model);
|
||||
} break;
|
||||
default:
|
||||
// do not cache
|
||||
ereport(NOTICE, (errmsg("clone model for type %d", (int)model->algorithm)));
|
||||
break;
|
||||
}
|
||||
|
||||
tuple = heap_form_tuple(rel->rd_att, values, nulls);
|
||||
(void)simple_heap_insert(rel, tuple);
|
||||
CatalogUpdateIndexes(rel, tuple);
|
||||
heap_freetuple_ext(tuple);
|
||||
heap_close(rel, RowExclusiveLock);
|
||||
}
|
||||
|
||||
// Get the model from the catalog tables
|
||||
Model *get_model(const char *model_name, bool only_model)
|
||||
{
|
||||
void *result = NULL;
|
||||
Model *model = NULL;
|
||||
Datum ListNames, ListValues, ListOids;
|
||||
bool isnull = false;
|
||||
bool isnullValue = false;
|
||||
bool isnullOid = false;
|
||||
AlgorithmML algorithm;
|
||||
|
||||
if (t_thrd.proc->workingVersionNum < 92304) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Before GRAND VERSION NUM 92304, we do not support gs_model_warehouse.")));
|
||||
}
|
||||
|
||||
HeapTuple tuple = SearchSysCache1(DB4AI_MODEL, CStringGetDatum(model_name));
|
||||
if (!HeapTupleIsValid(tuple)) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("There is no model called \"%s\".", model_name)));
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Form_gs_model_warehouse tuplePointer = (Form_gs_model_warehouse)GETSTRUCT(tuple);
|
||||
const char *modelType = TextDatumGetCString(SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_model_type, &isnull));
|
||||
|
||||
algorithm = get_algorithm_ml(modelType);
|
||||
if (algorithm == INVALID_ALGORITHM_ML) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("the type of model is invalid: %s", modelType)));
|
||||
}
|
||||
|
||||
if (only_model) {
|
||||
result = palloc0(sizeof(Model));
|
||||
model = (Model *)result;
|
||||
} else {
|
||||
switch (algorithm) {
|
||||
case LOGISTIC_REGRESSION:
|
||||
case SVM_CLASSIFICATION:
|
||||
case LINEAR_REGRESSION: {
|
||||
result = palloc0(sizeof(ModelGradientDescent));
|
||||
ModelGradientDescent *resGD = (ModelGradientDescent *)result;
|
||||
get_SGD(&tuple, resGD, tuplePointer);
|
||||
model = &(resGD->model);
|
||||
} break;
|
||||
case KMEANS: {
|
||||
result = palloc0(sizeof(ModelKMeans));
|
||||
ModelKMeans *resKmeans = (ModelKMeans *)result;
|
||||
get_kmeans(&tuple, resKmeans);
|
||||
model = &(resKmeans->model);
|
||||
} break;
|
||||
default:
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("the type of model is invalid: %s", modelType)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
model->algorithm = algorithm;
|
||||
model->model_name = model_name;
|
||||
model->exec_time_secs = tuplePointer->exec_time;
|
||||
model->pre_time_secs = tuplePointer->pre_process_time;
|
||||
model->processed_tuples = tuplePointer->processedtuples;
|
||||
model->discarded_tuples = tuplePointer->discardedtuples;
|
||||
model->return_type = tuplePointer->outputtype;
|
||||
model->num_actual_iterations = tuplePointer->iterations;
|
||||
model->sql = TextDatumGetCString(SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_query, &isnull));
|
||||
|
||||
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersNames, &isnull);
|
||||
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersValues, &isnullValue);
|
||||
ListOids = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_hyperparametersOids, &isnullOid);
|
||||
if (!isnull && !isnullValue && !isnullOid) {
|
||||
TupleToList<ListType::HYPERPARAMETERS>(model, &ListNames, &ListValues, &ListOids);
|
||||
}
|
||||
|
||||
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefNames, &isnull);
|
||||
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefValues, &isnullValue);
|
||||
ListOids = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_coefOids, &isnullOid);
|
||||
if (!isnull && !isnullValue && !isnullOid) {
|
||||
TupleToList<ListType::COEFS>(model, &ListNames, &ListValues, &ListOids);
|
||||
}
|
||||
|
||||
ListNames = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_trainingScoresName, &isnull);
|
||||
ListValues = SysCacheGetAttr(DB4AI_MODEL, tuple, Anum_gs_model_trainingScoresValue, &isnullValue);
|
||||
if (!isnull && !isnullValue) {
|
||||
TupleToList<ListType::SCORES>(model, &ListNames, &ListValues, NULL);
|
||||
}
|
||||
|
||||
ReleaseSysCache(tuple);
|
||||
return (Model *)result;
|
||||
}
|
||||
|
||||
void elog_model(int level, const Model *model)
|
||||
{
|
||||
Oid typoutput;
|
||||
bool typIsVarlena;
|
||||
ListCell *lc;
|
||||
StringInfoData buf;
|
||||
initStringInfo(&buf);
|
||||
const char* model_type = algorithm_ml_to_string(model->algorithm);
|
||||
|
||||
appendStringInfo(&buf, "\n:type %s", model_type);
|
||||
appendStringInfo(&buf, "\n:sql %s", model->sql);
|
||||
if (model->hyperparameters != nullptr) {
|
||||
appendStringInfoString(&buf, "\n:hyperparameters");
|
||||
foreach (lc, model->hyperparameters) {
|
||||
Hyperparameter *hyperp = lfirst_node(Hyperparameter, lc);
|
||||
getTypeOutputInfo(hyperp->type, &typoutput, &typIsVarlena);
|
||||
appendStringInfo(&buf, "\n :%s %s", hyperp->name, OidOutputFunctionCall(typoutput, hyperp->value));
|
||||
}
|
||||
}
|
||||
appendStringInfo(&buf, "\n:return type %u", model->return_type);
|
||||
appendStringInfo(&buf, "\n:pre-processing time %lf s", model->pre_time_secs);
|
||||
appendStringInfo(&buf, "\n:exec time %lf s", model->exec_time_secs);
|
||||
appendStringInfo(&buf, "\n:processed %ld tuples", model->processed_tuples);
|
||||
appendStringInfo(&buf, "\n:discarded %ld tuples", model->discarded_tuples);
|
||||
appendStringInfo(&buf, "\n:actual number of iterations %d", model->num_actual_iterations);
|
||||
if (model->train_info != nullptr) {
|
||||
appendStringInfoString(&buf, "\n:info");
|
||||
foreach (lc, model->train_info) {
|
||||
TrainingInfo *info = lfirst_node(TrainingInfo, lc);
|
||||
getTypeOutputInfo(info->type, &typoutput, &typIsVarlena);
|
||||
appendStringInfo(&buf, "\n :%s %s", info->name, OidOutputFunctionCall(typoutput, info->value));
|
||||
}
|
||||
}
|
||||
if (model->scores != nullptr) {
|
||||
appendStringInfoString(&buf, "\n:scores");
|
||||
foreach (lc, model->scores) {
|
||||
TrainingScore *score = lfirst_node(TrainingScore, lc);
|
||||
appendStringInfo(&buf, "\n :%s %.16g", score->name, score->value);
|
||||
}
|
||||
}
|
||||
if (model->algorithm == LOGISTIC_REGRESSION ||
|
||||
model->algorithm == SVM_CLASSIFICATION ||
|
||||
model->algorithm == LINEAR_REGRESSION) {
|
||||
ModelGradientDescent *model_gd = (ModelGradientDescent *)model;
|
||||
appendStringInfoString(&buf, "\n:gradient_descent:");
|
||||
appendStringInfo(&buf, "\n :algorithm %s", gd_get_algorithm(model_gd->model.algorithm)->name);
|
||||
getTypeOutputInfo(FLOAT4ARRAYOID, &typoutput, &typIsVarlena);
|
||||
appendStringInfo(&buf, "\n :weights %s", OidOutputFunctionCall(typoutput, model_gd->weights));
|
||||
if (model_gd->ncategories > 0) {
|
||||
Datum dt;
|
||||
bool isnull;
|
||||
bool first = true;
|
||||
struct varlena *src_arr = (struct varlena *)DatumGetPointer(model_gd->categories);
|
||||
ArrayType *arr = (ArrayType *)pg_detoast_datum(src_arr);
|
||||
Assert(arr->elemtype == model->return_type);
|
||||
ArrayIterator it = array_create_iterator(arr, 0);
|
||||
|
||||
getTypeOutputInfo(model->return_type, &typoutput, &typIsVarlena);
|
||||
appendStringInfo(&buf, "\n :categories %d {", model_gd->ncategories);
|
||||
while (array_iterate(it, &dt, &isnull)) {
|
||||
Assert(!isnull);
|
||||
appendStringInfo(&buf, "%s%s", first ? "" : ",", OidOutputFunctionCall(typoutput, dt));
|
||||
first = false;
|
||||
}
|
||||
appendStringInfoString(&buf, "}");
|
||||
|
||||
array_free_iterator(it);
|
||||
if (arr != (ArrayType *)src_arr)
|
||||
pfree(arr);
|
||||
}
|
||||
}
|
||||
elog(level, "Model=%s%s", model->model_name, buf.data);
|
||||
pfree(buf.data);
|
||||
}
|
||||
|
||||
template <ListType ltype> void ListToTuple(List *list, Datum *name, Datum *value, Datum *oid)
|
||||
{
|
||||
text *t_names, *t_values;
|
||||
Datum *array_container = nullptr;
|
||||
int iter = 0;
|
||||
ArrayBuildState *astateName = NULL, *astateValue = NULL;
|
||||
|
||||
array_container = (Datum *)palloc0(list->length * sizeof(Datum));
|
||||
foreach_cell(it, list)
|
||||
{
|
||||
switch (ltype) {
|
||||
case ListType::HYPERPARAMETERS: {
|
||||
Hyperparameter *cell = (Hyperparameter *)lfirst(it);
|
||||
t_names = cstring_to_text(cell->name);
|
||||
t_values = cstring_to_text(Datum_to_string(cell->value, cell->type, false));
|
||||
array_container[iter] = ObjectIdGetDatum(cell->type);
|
||||
astateName =
|
||||
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue =
|
||||
accumArrayResult(astateValue, PointerGetDatum(t_values), false, TEXTOID, CurrentMemoryContext);
|
||||
iter++;
|
||||
} break;
|
||||
case ListType::COEFS: {
|
||||
TrainingInfo *cell = (TrainingInfo *)lfirst(it);
|
||||
t_names = cstring_to_text(cell->name);
|
||||
t_values = cstring_to_text(Datum_to_string(cell->value, cell->type, false));
|
||||
array_container[iter] = ObjectIdGetDatum(cell->type);
|
||||
astateName =
|
||||
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue =
|
||||
accumArrayResult(astateValue, PointerGetDatum(t_values), false, TEXTOID, CurrentMemoryContext);
|
||||
iter++;
|
||||
} break;
|
||||
case ListType::SCORES: {
|
||||
TrainingScore *cell = (TrainingScore *)lfirst(it);
|
||||
t_names = cstring_to_text(cell->name);
|
||||
array_container[iter] = Float4GetDatum(cell->value);
|
||||
astateName =
|
||||
accumArrayResult(astateName, PointerGetDatum(t_names), false, TEXTOID, CurrentMemoryContext);
|
||||
iter++;
|
||||
} break;
|
||||
default: {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Not support saving this data into model warehouse.")));
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
switch (ltype) {
|
||||
case ListType::HYPERPARAMETERS:
|
||||
case ListType::COEFS: {
|
||||
*name = makeArrayResult(astateName, CurrentMemoryContext);
|
||||
*value = makeArrayResult(astateValue, CurrentMemoryContext);
|
||||
ArrayType *oid_array = construct_array(array_container, list->length, OIDOID, sizeof(Oid), true, 'i');
|
||||
*oid = PointerGetDatum(oid_array);
|
||||
} break;
|
||||
case ListType::SCORES: {
|
||||
*name = makeArrayResult(astateName, CurrentMemoryContext);
|
||||
ArrayType *value_array =
|
||||
construct_array(array_container, list->length, FLOAT4OID, sizeof(float4), FLOAT4PASSBYVAL, 'i');
|
||||
*value = PointerGetDatum(value_array);
|
||||
} break;
|
||||
default: {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Not support saving this data into model warehouse.")));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ListType listType> void TupleToList(Model *model, Datum *names, Datum *values, Datum *oids)
|
||||
{
|
||||
char *strNames, *strValues;
|
||||
Oid tranOids;
|
||||
Datum dtNames, dtValues, dtOid;
|
||||
ArrayType *arrNames, *arrValues, *arrOids;
|
||||
bool isnull;
|
||||
ArrayIterator itOid = NULL;
|
||||
|
||||
arrNames = DatumGetArrayTypeP(*names);
|
||||
arrValues = DatumGetArrayTypeP(*values);
|
||||
|
||||
ArrayIterator itName = array_create_iterator(arrNames, 0);
|
||||
ArrayIterator itValue = array_create_iterator(arrValues, 0);
|
||||
if (oids != NULL) {
|
||||
arrOids = DatumGetArrayTypeP(*oids);
|
||||
itOid = array_create_iterator(arrOids, 0);
|
||||
}
|
||||
while (array_iterate(itName, &dtNames, &isnull)) {
|
||||
array_iterate(itValue, &dtValues, &isnull);
|
||||
switch (listType) {
|
||||
case ListType::HYPERPARAMETERS: {
|
||||
array_iterate(itOid, &dtOid, &isnull);
|
||||
strNames = TextDatumGetCString(dtNames);
|
||||
strValues = TextDatumGetCString(dtValues);
|
||||
tranOids = DatumGetObjectId(dtOid);
|
||||
dtValues = string_to_datum(strValues, tranOids);
|
||||
add_model_parameter<listType>(model, strNames, tranOids, dtValues);
|
||||
} break;
|
||||
case ListType::COEFS: {
|
||||
array_iterate(itOid, &dtOid, &isnull);
|
||||
strNames = TextDatumGetCString(dtNames);
|
||||
strValues = TextDatumGetCString(dtValues);
|
||||
tranOids = DatumGetObjectId(dtOid);
|
||||
dtValues = string_to_datum(strValues, tranOids);
|
||||
add_model_parameter<listType>(model, strNames, tranOids, dtValues);
|
||||
} break;
|
||||
case ListType::SCORES: {
|
||||
strNames = TextDatumGetCString(dtNames);
|
||||
add_model_parameter<listType>(model, strNames, 0, dtValues);
|
||||
} break;
|
||||
default: {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Not support fetching this data from model warehouse.")));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <ListType ltype> static void add_model_parameter(Model *model, const char *name, Oid type, Datum value)
|
||||
{
|
||||
switch (ltype) {
|
||||
case ListType::HYPERPARAMETERS: {
|
||||
Hyperparameter *hyperp = (Hyperparameter *)palloc0(sizeof(Hyperparameter));
|
||||
hyperp->name = pstrdup(name);
|
||||
hyperp->type = type;
|
||||
hyperp->value = value;
|
||||
model->hyperparameters = lappend(model->hyperparameters, hyperp);
|
||||
} break;
|
||||
case ListType::COEFS: {
|
||||
TrainingInfo *tinfo = (TrainingInfo *)palloc0(sizeof(TrainingInfo));
|
||||
tinfo->name = pstrdup(name);
|
||||
tinfo->type = type;
|
||||
tinfo->value = value;
|
||||
model->train_info = lappend(model->train_info, tinfo);
|
||||
} break;
|
||||
case ListType::SCORES: {
|
||||
TrainingScore *tscore = (TrainingScore *)palloc0(sizeof(TrainingScore));
|
||||
tscore->name = pstrdup(name);
|
||||
tscore->value = DatumGetFloat4(value);
|
||||
model->scores = lappend(model->scores, tscore);
|
||||
} break;
|
||||
default: {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Not support put this data into Model-struct.")));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static Datum string_to_datum(const char *str, Oid datatype)
|
||||
{
|
||||
switch (datatype) {
|
||||
case BOOLOID:
|
||||
return DirectFunctionCall1(boolin, CStringGetDatum(str));
|
||||
case INT1OID:
|
||||
case INT2OID:
|
||||
case INT4OID:
|
||||
return Int32GetDatum(atoi(str));
|
||||
case INT8OID:
|
||||
return Int64GetDatum(atoi(str));
|
||||
case VARCHAROID:
|
||||
case BPCHAROID:
|
||||
case CHAROID:
|
||||
case TEXTOID:
|
||||
return CStringGetTextDatum(str);
|
||||
case FLOAT4OID:
|
||||
case FLOAT8OID:
|
||||
return DirectFunctionCall1(float8in, CStringGetDatum(str));
|
||||
default:
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("The type is not supported: %d", datatype)));
|
||||
return CStringGetTextDatum(str);
|
||||
}
|
||||
}
|
||||
|
||||
/* store SGD model */
|
||||
void store_SGD(Datum *values, bool *nulls, ModelGradientDescent *SGDmodel)
|
||||
{
|
||||
nulls[Anum_gs_model_modelData - 1] = true;
|
||||
nulls[Anum_gs_model_modeldescribe - 1] = true;
|
||||
|
||||
values[Anum_gs_model_model_type - 1] = CStringGetTextDatum(algorithm_ml_to_string(SGDmodel->model.algorithm));
|
||||
values[Anum_gs_model_weight - 1] = SGDmodel->weights;
|
||||
if (SGDmodel->ncategories > 0) {
|
||||
text *categoriesName, *categoriesValue;
|
||||
ArrayBuildState *astate = NULL;
|
||||
Datum dt;
|
||||
bool isnull;
|
||||
|
||||
nulls[Anum_gs_model_coefNames - 1] = false;
|
||||
categoriesName = cstring_to_text("categories");
|
||||
astate = accumArrayResult(astate, PointerGetDatum(categoriesName), false, TEXTOID, CurrentMemoryContext);
|
||||
values[Anum_gs_model_coefNames - 1] = makeArrayResult(astate, CurrentMemoryContext);
|
||||
astate = NULL;
|
||||
|
||||
ArrayType *arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(SGDmodel->categories));
|
||||
ArrayIterator it = array_create_iterator(arr, 0);
|
||||
while (array_iterate(it, &dt, &isnull)) {
|
||||
categoriesValue = cstring_to_text(Datum_to_string(dt, SGDmodel->model.return_type, false));
|
||||
astate = accumArrayResult(astate, PointerGetDatum(categoriesValue), false, TEXTOID, CurrentMemoryContext);
|
||||
}
|
||||
values[Anum_gs_model_coefValues - 1] = makeArrayResult(astate, CurrentMemoryContext);
|
||||
nulls[Anum_gs_model_coefValues - 1] = false;
|
||||
}
|
||||
}
|
||||
|
||||
/* get SGD model */
|
||||
void get_SGD(HeapTuple *tuple, ModelGradientDescent *resGD, Form_gs_model_warehouse tuplePointer)
|
||||
{
|
||||
char *strValues;
|
||||
Datum dtValues;
|
||||
ArrayBuildState *astate = NULL;
|
||||
bool isnull = false;
|
||||
|
||||
/* weight */
|
||||
resGD->weights = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_weight, &isnull);
|
||||
|
||||
/* categories */
|
||||
resGD->ncategories = 0;
|
||||
Datum dtCat = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefValues, &isnull);
|
||||
|
||||
if (!isnull) {
|
||||
ArrayType *arrValues = DatumGetArrayTypeP(dtCat);
|
||||
ArrayIterator itValue = array_create_iterator(arrValues, 0);
|
||||
while (array_iterate(itValue, &dtValues, &isnull)) {
|
||||
resGD->ncategories++;
|
||||
strValues = TextDatumGetCString(dtValues);
|
||||
dtValues = string_to_datum(strValues, tuplePointer->outputtype);
|
||||
astate = accumArrayResult(astate, dtValues, false, tuplePointer->outputtype, CurrentMemoryContext);
|
||||
}
|
||||
resGD->categories = makeArrayResult(astate, CurrentMemoryContext);
|
||||
} else {
|
||||
resGD->categories = PointerGetDatum(NULL);
|
||||
}
|
||||
}
|
||||
|
||||
/* get kmeans model */
|
||||
void store_kmeans(Datum *values, bool *nulls, ModelKMeans *kmeansModel)
|
||||
{
|
||||
ArrayBuildState *astateName = NULL, *astateValue = NULL, *astateDescribe = NULL;
|
||||
text *tValue, *txDescribe, *txCoodinate;
|
||||
int lenDescribe = 200 * sizeof(char), lengthD = 0, lengthC = 0;
|
||||
int lenCoodinate = 15 * kmeansModel->dimension * kmeansModel->actual_num_centroids;
|
||||
WHCentroid *centroid;
|
||||
double *coordinateContainer;
|
||||
char *describeElem, *strCoordinates = (char *)palloc0(lenCoodinate);
|
||||
|
||||
values[Anum_gs_model_outputType - 1] = ObjectIdGetDatum(kmeansModel->model.return_type);
|
||||
nulls[Anum_gs_model_modelData - 1] = true;
|
||||
nulls[Anum_gs_model_weight - 1] = true;
|
||||
values[Anum_gs_model_model_type - 1] = CStringGetTextDatum(algorithm_ml_to_string(KMEANS));
|
||||
|
||||
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->original_num_centroids), INT8OID, false));
|
||||
astateName = accumArrayResult(astateName, CStringGetTextDatum("original_num_centroids"), false, TEXTOID,
|
||||
CurrentMemoryContext);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->actual_num_centroids), INT8OID, false));
|
||||
astateName =
|
||||
accumArrayResult(astateName, CStringGetTextDatum("actual_num_centroids"), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->dimension), INT8OID, false));
|
||||
astateName = accumArrayResult(astateName, CStringGetTextDatum("dimension"), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->distance_function_id), INT8OID, false));
|
||||
astateName =
|
||||
accumArrayResult(astateName, CStringGetTextDatum("distance_function_id"), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
tValue = cstring_to_text(Datum_to_string(Int64GetDatum(kmeansModel->seed), INT8OID, false));
|
||||
astateName = accumArrayResult(astateName, CStringGetTextDatum("seed"), false, TEXTOID, CurrentMemoryContext);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(tValue), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
astateName = accumArrayResult(astateName, CStringGetTextDatum("coordinates"), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
for (uint32_t i = 0; i < kmeansModel->actual_num_centroids; i++) {
|
||||
lengthD = 0;
|
||||
describeElem = (char *)palloc0(lenDescribe);
|
||||
centroid = kmeansModel->centroids + i;
|
||||
|
||||
lengthD = sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "id:%d,", centroid->id);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "objective_function:%f,",
|
||||
centroid->objective_function);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "avg_distance_to_centroid:%f,",
|
||||
centroid->avg_distance_to_centroid);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "min_distance_to_centroid:%f,",
|
||||
centroid->min_distance_to_centroid);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "max_distance_to_centroid:%f,",
|
||||
centroid->max_distance_to_centroid);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "std_dev_distance_to_centroid:%f,",
|
||||
centroid->std_dev_distance_to_centroid);
|
||||
lengthD += sprintf_s(describeElem + lengthD, lenDescribe - lengthD, "cluster_size:%d", centroid->cluster_size);
|
||||
|
||||
txDescribe = cstring_to_text(describeElem);
|
||||
astateDescribe =
|
||||
accumArrayResult(astateDescribe, PointerGetDatum(txDescribe), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
coordinateContainer = centroid->coordinates;
|
||||
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, "(");
|
||||
for (uint32_t j = 0; j < kmeansModel->dimension; j++) {
|
||||
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, "%f,", coordinateContainer[j]);
|
||||
}
|
||||
lengthC--;
|
||||
lengthC += sprintf_s(strCoordinates + lengthC, lenCoodinate - lengthC, ")");
|
||||
}
|
||||
txCoodinate = cstring_to_text(strCoordinates);
|
||||
astateValue = accumArrayResult(astateValue, PointerGetDatum(txCoodinate), false, TEXTOID, CurrentMemoryContext);
|
||||
|
||||
values[Anum_gs_model_modeldescribe - 1] = makeArrayResult(astateDescribe, CurrentMemoryContext);
|
||||
values[Anum_gs_model_coefValues - 1] = makeArrayResult(astateValue, CurrentMemoryContext);
|
||||
values[Anum_gs_model_coefNames - 1] = makeArrayResult(astateName, CurrentMemoryContext);
|
||||
|
||||
nulls[Anum_gs_model_coefValues - 1] = false;
|
||||
nulls[Anum_gs_model_coefNames - 1] = false;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
/*
 * get_kmeans
 * Rebuilds an in-memory k-means model (ModelKMeans) from its serialized
 * gs_model catalog tuple.
 *
 * The scalar coefficients (centroid counts, seed, dimension, distance
 * function) and the flattened centroid coordinates are stored as parallel
 * text arrays in coefNames/coefValues; per-centroid statistics live in the
 * modeldescribe array, one entry per centroid.
 *
 * tuple       - catalog tuple of the model (syscache DB4AI_MODEL)
 * modelKmeans - output structure; centroids are palloc'd here
 */
void get_kmeans(HeapTuple *tuple, ModelKMeans *modelKmeans)
{
    Datum dtValue, dtName;
    bool isnull;   /* reused for every fetch/iterate; result is not inspected */
    char *strValue, *strName, *coordinates = NULL;
    uint32_t coefContainer;
    int offset = 0;
    WHCentroid *curseCent;

    modelKmeans->model.algorithm = KMEANS;

    /* coef: names and values are parallel arrays walked in lockstep */
    Datum dtCoefValues = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefValues, &isnull);
    ArrayType *arrValues = DatumGetArrayTypeP(dtCoefValues);
    ArrayIterator itValue = array_create_iterator(arrValues, 0);

    Datum dtCoefNames = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_coefNames, &isnull);
    ArrayType *arrNames = DatumGetArrayTypeP(dtCoefNames);
    ArrayIterator itName = array_create_iterator(arrNames, 0);

    while (array_iterate(itName, &dtName, &isnull)) {
        array_iterate(itValue, &dtValue, &isnull);
        strName = TextDatumGetCString(dtName);
        strValue = TextDatumGetCString(dtValue);
        /* numeric coefs are serialized as text; "coordinates" ignores this */
        coefContainer = atoi(strValue);
        if (strcmp(strName, "original_num_centroids") == 0) {
            modelKmeans->original_num_centroids = coefContainer;
        } else if (strcmp(strName, "actual_num_centroids") == 0) {
            modelKmeans->actual_num_centroids = coefContainer;
        } else if (strcmp(strName, "seed") == 0) {
            modelKmeans->seed = coefContainer;
        } else if (strcmp(strName, "dimension") == 0) {
            modelKmeans->dimension = coefContainer;
        } else if (strcmp(strName, "distance_function_id") == 0) {
            modelKmeans->distance_function_id = coefContainer;
        } else if (strcmp(strName, "coordinates") == 0) {
            /* keep the raw string; it is parsed below once dimension and
             * actual_num_centroids are known */
            coordinates = strValue;
        } else {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
                errmsg("the coef should not be here in KMEANS: %s", strName)));
        }
    }

    modelKmeans->centroids =
        reinterpret_cast<WHCentroid *>(palloc0(sizeof(WHCentroid) * modelKmeans->actual_num_centroids));

    /* describe: one entry per centroid with its statistics */
    Datum dtDescribe = SysCacheGetAttr(DB4AI_MODEL, *tuple, Anum_gs_model_modeldescribe, &isnull);
    ArrayType *arrDescribe = DatumGetArrayTypeP(dtDescribe);
    ArrayIterator itDescribe = array_create_iterator(arrDescribe, 0);

    /*
     * For centroid i: consume the next `dimension` coordinates from the
     * shared coordinates string (splitStringFillCoordinates returns the
     * resume point), then parse the statistics of this describe entry.
     */
    while (array_iterate(itDescribe, &dtName, &isnull)) {
        curseCent = modelKmeans->centroids + offset;
        strName = TextDatumGetCString(dtName);
        coordinates = splitStringFillCoordinates(curseCent, coordinates, modelKmeans->dimension);
        splitStringFillCentroid(curseCent, strName);
        offset++;
    }
}
|
||||
|
||||
/*
 * splitStringFillCentroid
 * Parses one serialized centroid-statistics string of the form
 * "key:value,key:value,..." (as produced when the model was stored) and
 * fills the corresponding fields of the centroid.
 *
 * curseCent   - centroid to populate
 * strDescribe - mutable describe string; consumed destructively by strtok_r
 *
 * Raises ERROR on an unrecognized key.
 */
void splitStringFillCentroid(WHCentroid *curseCent, char *strDescribe)
{
    char *cur, *name, *context = NULL;
    Datum dtCur;

    /* tokens alternate: key, value, key, value, ... */
    name = strtok_r(strDescribe, ":,", &context);
    cur = strtok_r(NULL, ":,", &context);
    while (cur != NULL && name != NULL) {
        if (strcmp(name, "id") == 0) {
            dtCur = string_to_datum(cur, INT8OID);
            curseCent->id = DatumGetUInt32(dtCur);
        } else if (strcmp(name, "objective_function") == 0) {
            dtCur = string_to_datum(cur, FLOAT8OID);
            curseCent->objective_function = DatumGetFloat8(dtCur);
        } else if (strcmp(name, "avg_distance_to_centroid") == 0) {
            dtCur = string_to_datum(cur, FLOAT8OID);
            curseCent->avg_distance_to_centroid = DatumGetFloat8(dtCur);
        } else if (strcmp(name, "min_distance_to_centroid") == 0) {
            dtCur = string_to_datum(cur, FLOAT8OID);
            curseCent->min_distance_to_centroid = DatumGetFloat8(dtCur);
        } else if (strcmp(name, "max_distance_to_centroid") == 0) {
            dtCur = string_to_datum(cur, FLOAT8OID);
            curseCent->max_distance_to_centroid = DatumGetFloat8(dtCur);
        } else if (strcmp(name, "std_dev_distance_to_centroid") == 0) {
            dtCur = string_to_datum(cur, FLOAT8OID);
            curseCent->std_dev_distance_to_centroid = DatumGetFloat8(dtCur);
        } else if (strcmp(name, "cluster_size") == 0) {
            dtCur = string_to_datum(cur, INT8OID);
            curseCent->cluster_size = DatumGetUInt64(dtCur);
        } else {
            /* BUGFIX: report the unknown key, not its value (was `cur`),
             * matching the analogous error in get_kmeans */
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
                errmsg("this description should not be here in KMEANS: %s", name)));
        }
        name = strtok_r(NULL, ":,", &context);
        cur = strtok_r(NULL, ":,", &context);
    }
}
|
||||
|
||||
/*
 * splitStringFillCoordinates
 * Consumes exactly `dimension` numeric tokens from a coordinates string of
 * the form "(x,y,...)(x,y,...)..." and stores them as this centroid's
 * coordinate vector.
 *
 * curseCent      - centroid whose coordinates array is palloc'd and filled
 * strCoordinates - mutable string; tokenized destructively by strtok_r
 * dimension      - number of coordinates belonging to this centroid
 *
 * Returns the strtok_r continuation pointer so the caller can resume parsing
 * the next centroid from the same string. Raises ERROR when the string runs
 * out of tokens before `dimension` values were read.
 */
char *splitStringFillCoordinates(WHCentroid *curseCent, char *strCoordinates, int dimension)
{
    char *saveptr = NULL;
    /* first strtok_r call consumes the string; subsequent calls pass NULL */
    char *source = strCoordinates;
    double *parsed = (double *)palloc0(dimension * sizeof(double));

    for (int idx = 0; idx < dimension; ++idx) {
        char *token = strtok_r(source, ")(,", &saveptr);
        source = NULL;
        if (token == NULL) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("the Coordinates result seems not match their dimension or actual_num_centroids.")));
        }
        parsed[idx] = DatumGetFloat8(string_to_datum(token, FLOAT8OID));
    }

    curseCent->coordinates = parsed;
    return saveptr;
}
|
|
@ -0,0 +1,24 @@
|
|||
# Build script for the DB4AI commands component (static object target).

# Collect every source file in this directory.
AUX_SOURCE_DIRECTORY(${CMAKE_CURRENT_SOURCE_DIR} TGT_commands_SRC)

# Include paths: project headers plus bundled third-party dependencies.
set(TGT_commands_INC
    ${PROJECT_OPENGS_DIR}/contrib/log_fdw
    ${PROJECT_TRUNK_DIR}/distribute/bin/gds
    ${PROJECT_SRC_DIR}/include/libcomm
    ${PROJECT_SRC_DIR}/include
    ${PROJECT_SRC_DIR}/lib/gstrace
    ${LZ4_INCLUDE_PATH}
    ${LIBCGROUP_INCLUDE_PATH}
    ${LIBORC_INCLUDE_PATH}
    ${EVENT_INCLUDE_PATH}
    ${PROTOBUF_INCLUDE_PATH}
    ${ZLIB_INCLUDE_PATH}
)

set(commands_DEF_OPTIONS ${MACRO_OPTIONS})
set(commands_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
# This component is linked into a position-independent executable: swap -fPIC for -fPIE.
list(REMOVE_ITEM commands_COMPILE_OPTIONS -fPIC)
set(commands_COMPILE_OPTIONS ${commands_COMPILE_OPTIONS} -std=c++14 -fPIE)
set(commands_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_commands TGT_commands_SRC TGT_commands_INC "${commands_DEF_OPTIONS}" "${commands_COMPILE_OPTIONS}" "${commands_LINK_OPTIONS}")
|
|
@ -0,0 +1,22 @@
|
|||
#---------------------------------------------------------------------------------------
#
# Makefile for the DB4AI command layer (CREATE MODEL / PREDICT BY).
#
# IDENTIFICATION
#        src/gausskernel/dbmind/commands/Makefile
#
# ---------------------------------------------------------------------------------------

subdir = src/gausskernel/dbmind/db4ai/commands
top_builddir = ../../../../..
include $(top_builddir)/src/Makefile.global

# Pull in auto-generated dependency files for incremental builds, but skip
# them for clean/distclean targets and when building under the hutaf_llt
# coverage toolchain (its g++ wrapper does not produce them).
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif

OBJS = create_model.o predict_by.o

include $(top_srcdir)/src/gausskernel/common.mk
|
|
@ -0,0 +1,751 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
 * create_model.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/commands/create_model.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
#include "db4ai/create_model.h"
|
||||
|
||||
#include "postgres.h"
|
||||
#include "knl/knl_variable.h"
|
||||
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "db4ai/hyperparameter_validation.h"
|
||||
#include "catalog/indexing.h"
|
||||
#include "executor/executor.h"
|
||||
#include "executor/nodeKMeans.h"
|
||||
#include "nodes/value.h"
|
||||
#include "parser/analyze.h"
|
||||
#include "rewrite/rewriteHandler.h"
|
||||
#include "utils/snapmgr.h"
|
||||
#include "tcop/tcopprot.h"
|
||||
#include "utils/lsyscache.h"
|
||||
#include "utils/rel.h"
|
||||
#include "workload/workload.h"
|
||||
#include "executor/nodeGD.h"
|
||||
#include "db4ai/aifuncs.h"
|
||||
#include "utils/builtins.h"
|
||||
|
||||
extern void exec_simple_plan(PlannedStmt *plan); // defined in postgres.cpp
|
||||
bool verify_pgarray(ArrayType const * pg_array, int32_t n); // defined in kmeans.cpp
|
||||
|
||||
/*
|
||||
* Common setup needed by both normal execution and EXPLAIN ANALYZE.
|
||||
* This setup is adapted from SetupForCreateTableAs
|
||||
*/
|
||||
/*
 * Shared preparation for CREATE MODEL, adapted from SetupForCreateTableAs:
 * run the rule rewriter over the already-analyzed SELECT and return the
 * single resulting query.
 */
static Query *setup_for_create_model(Query *query, /* IntoClause *into, */ const char *queryString,
    ParamListInfo params /* , DestReceiver *dest */)
{
    Assert(query->commandType == CMD_SELECT);

    /*
     * Parse analysis already happened; only rewriting remains. The rewriter
     * and planner scribble on their input, so rewrite a copy of the query
     * tree — the original must survive repeated execution (portal, plpgsql;
     * same hack as EXPLAIN and PREPARE). We skip AcquireRewriteLocks: the
     * query either came straight from the parser or plancache.c already
     * acquired suitable locks.
     */
    List *rewritten = QueryRewrite((Query *)copyObject(query));

    /* A SELECT must rewrite to exactly one SELECT. */
    if (list_length(rewritten) != 1) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
            errmsg("Unexpected rewrite result for CREATE MODEL statement")));
    }

    return (Query *)linitial(rewritten);
}
|
||||
|
||||
// Create an GradientDescent execution node with a given configuration
|
||||
/*
 * Build a GradientDescent execution node configured from the user-supplied
 * hyperparameter list, and expose its target list through the destination
 * receiver for later model serialization.
 */
static GradientDescent *create_gd_node(AlgorithmML algorithm, List *hyperparameters, DestReceiverTrainModel *dest)
{
    GradientDescent *node = makeNode(GradientDescent);

    node->algorithm = algorithm;
    configure_hyperparameters(algorithm, hyperparameters, dest->model, node);

    /*
     * A zero seed means "not provided": derive one from the wall clock
     * (time(NULL) is never zero again — zero is the epoch, in the past)
     * and record the effective value in the stored model.
     */
    if (node->seed == 0) {
        node->seed = time(NULL);
        update_model_hyperparameter(dest->model, "seed", INT4OID, Int32GetDatum(node->seed));
    }

    node->plan.type = T_GradientDescent;
    node->plan.targetlist = makeGradientDescentExpr(algorithm, nullptr, 1);
    dest->targetlist = node->plan.targetlist;

    return node;
}
|
||||
|
||||
// Add a GradientDescent operator at the root of the plan
|
||||
/* Graft a GradientDescent training operator onto the root of the plan tree. */
static PlannedStmt *add_GradientDescent_to_plan(PlannedStmt *plan, AlgorithmML algorithm, List *hyperparameters,
    DestReceiverTrainModel *dest)
{
    GradientDescent *node = create_gd_node(algorithm, hyperparameters, dest);

    /* the former root becomes the training operator's only child */
    node->plan.lefttree = plan->planTree;
    plan->planTree = &node->plan;

    return plan;
}
|
||||
|
||||
/*
 * Map the user-facing distance-function name onto its DistanceFunction enum.
 * Raises ERROR for anything outside the supported set.
 */
static DistanceFunction get_kmeans_distance(const char *distance_func)
{
    if (strcmp(distance_func, "L1") == 0)
        return KMEANS_L1;
    if (strcmp(distance_func, "L2") == 0)
        return KMEANS_L2;
    if (strcmp(distance_func, "L2_Squared") == 0)
        return KMEANS_L2_SQUARED;
    if (strcmp(distance_func, "Linf") == 0)
        return KMEANS_LINF;

    ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
        errmsg("No known distance function chosen. Current candidates are: "
        "L1, L2, L2_Squared (default), Linf")));
    return KMEANS_L2_SQUARED; /* not reached; silences compiler warnings */
}
|
||||
|
||||
/*
 * Map the user-facing seeding-function name onto its SeedingFunction enum.
 * Raises ERROR for anything outside the supported set.
 */
static SeedingFunction get_kmeans_seeding(const char *seeding_func)
{
    if (strcmp(seeding_func, "Random++") == 0)
        return KMEANS_RANDOM_SEED;
    if (strcmp(seeding_func, "KMeans||") == 0)
        return KMEANS_BB;

    ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
        errmsg("No known seeding function chosen. Current candidates are: Random++ (default), KMeans||")));
    return KMEANS_RANDOM_SEED; /* not reached; silences compiler warnings */
}
|
||||
|
||||
/*
 * create_kmeans_node
 * Builds a KMeans plan node from the user-supplied hyperparameter list,
 * validating each value, and primes the ModelKMeans descriptor attached to
 * the destination receiver with the fields needed later for prediction.
 *
 * Every set_hyperparameter call also records the effective value (default or
 * user-provided) in dest->model. Raises ERROR on any out-of-range value.
 */
static KMeans *create_kmeans_node(AlgorithmML const algorithm, List *hyperparameters, DestReceiverTrainModel *dest)
{
    KMeans *kmeans_node = makeNode(KMeans);
    char *distance_func = nullptr;
    char *seeding_func = nullptr;
    double tolerance = 0.;
    int32_t num_iterations = 0;
    int32_t num_centroids = 0;
    int32_t const max_num_centroids = 1000000;  /* shared upper bound for several validations */
    int32_t batch_size = 0;
    int32_t num_features = 0;
    int32_t external_seed = 0;
    int32_t verbosity = 0;
    auto kmeans_model = reinterpret_cast<ModelKMeans *>(dest->model);
    HyperparameterValidation validation;
    memset_s(&validation, sizeof(HyperparameterValidation), 0, sizeof(HyperparameterValidation));

    kmeans_node->algorithm = algorithm;
    kmeans_node->plan.type = T_KMeans;
    kmeans_model->model.return_type = INT4OID;  /* predictions are INT4 values */

    /* max_iterations: default 10, must be strictly positive */
    set_hyperparameter<int32_t>("max_iterations", &num_iterations, hyperparameters, 10, dest->model, &validation);

    if (unlikely(num_iterations <= 0)) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("max_iterations must be in [1, %d]", INT_MAX)));
    } else {
        kmeans_node->parameters.num_iterations = num_iterations;
    }
    /* num_centroids: default 10, in [1, max_num_centroids] */
    set_hyperparameter<int32_t>("num_centroids", &num_centroids, hyperparameters, 10, dest->model, &validation);

    if (unlikely((num_centroids <= 0) || (num_centroids > max_num_centroids))) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("num_centroids must be in [1, %d]", max_num_centroids)));
    } else {
        kmeans_node->parameters.num_centroids = num_centroids;
    }
    /* tolerance: convergence threshold, default 1e-5, in (0, 1] */
    set_hyperparameter<double>("tolerance", &tolerance, hyperparameters, 0.00001, dest->model, &validation);

    if (unlikely((tolerance <= 0.) || (tolerance > 1.))) {
        ereport(ERROR,
            (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("tolerance must be in (0, 1.0]")));
    } else {
        kmeans_node->parameters.tolerance = tolerance;
    }
    /* batch_size: default 10, in [1, max_num_centroids] */
    set_hyperparameter<int32_t>("batch_size", &batch_size, hyperparameters, 10, dest->model, &validation);

    if (unlikely((batch_size <= 0) || (batch_size > max_num_centroids))) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("batch_size must be in [1, %d]", max_num_centroids)));
    } else {
        kmeans_node->description.batch_size = batch_size;
    }
    /* num_features: dimensionality of input points, default 2 */
    set_hyperparameter<int32_t>("num_features", &num_features, hyperparameters, 2, dest->model, &validation);

    if (unlikely((num_features <= 0) || (num_features > max_num_centroids))) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("num_features must be in [1, %d]", max_num_centroids)));
    } else {
        kmeans_node->description.n_features = num_features;
    }
    /* distance_function: default L2_Squared; validated in get_kmeans_distance */
    set_hyperparameter<char *>("distance_function", &distance_func, hyperparameters, "L2_Squared", dest->model,
        &validation);

    kmeans_node->description.distance = get_kmeans_distance(distance_func);

    /* seeding_function: default Random++; validated in get_kmeans_seeding */
    set_hyperparameter<char *>("seeding_function", &seeding_func, hyperparameters, "Random++", dest->model,
        &validation);

    kmeans_node->description.seeding = get_kmeans_seeding(seeding_func);

    /* verbose: 0 (silent, default) .. 2 (full output) */
    set_hyperparameter<int32_t>("verbose", &verbosity, hyperparameters, false, dest->model, &validation);

    if (verbosity < 0 || verbosity > 2)
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Verbosity level must be between 0 (no output), 1 (less output), or 2 (full output)")));
    else
        kmeans_node->description.verbosity = static_cast<Verbosity>(verbosity);

    /*
     * unfortunately the system parses an int64_t as T_Float whenever the it does not fit into a int32_t
     * thus, an int64_t might be internally parsed as a T_Integer or a T_Float depending on whether
     * the precision fits into an int32_t. thus, we accept a small seed (int32_t) that is xor'ed with
     * a random long internal seed.
     */
    set_hyperparameter<int32_t>("seed", &external_seed, hyperparameters, 0, dest->model, &validation);

    /*
     * the seed used for the algorithm is the xor of the seed provided by the user with
     * a random (but fixed) internal seed. as long as the internal seed is kept unchanged
     * results will be reproducible (see nodeKMeans.cpp)
     */
    if (external_seed < 0)
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("seed must be in [0, %d]", INT_MAX)));
    else
        kmeans_node->parameters.external_seed = static_cast<uint64_t>(external_seed);

    /*
     * these fields are propagated all the way to store_model and used for prediction
     * the value of fields of ModelKMeans not set here change during execution
     * and thus are set in the very end, when the model is about to be stored
     */
    kmeans_model->dimension = kmeans_node->description.n_features;
    kmeans_model->original_num_centroids = kmeans_node->parameters.num_centroids;
    kmeans_model->distance_function_id = kmeans_node->description.distance;

    /* set_hyperparameter allocated these strings; they are no longer needed */
    pfree(distance_func);
    pfree(seeding_func);
    return kmeans_node;
}
|
||||
|
||||
// Add a k-means operator at the root of the plan
|
||||
/* Graft a k-means training operator onto the root of the plan tree. */
static PlannedStmt *add_kmeans_to_plan(PlannedStmt *plan, AlgorithmML algorithm, List *hyperparameters,
    DestReceiverTrainModel *dest)
{
    /* defensive: this builder only knows how to attach k-means */
    if (unlikely(algorithm != KMEANS)) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Algorithm is not the expected %u (k-means). Provided %u", KMEANS, algorithm)));
    }

    KMeans *node = create_kmeans_node(algorithm, hyperparameters, dest);

    /* the former root becomes the training operator's only child */
    node->plan.lefttree = plan->planTree;
    plan->planTree = &node->plan;

    return plan;
}
|
||||
|
||||
|
||||
// Add the ML algorithm at the root of the plan according to the CreateModelStmt
|
||||
/*
 * Dispatch on the CREATE MODEL statement's algorithm and attach the matching
 * training operator at the root of the plan. Raises ERROR for algorithms the
 * DB4AI layer does not support.
 */
static PlannedStmt *add_create_model_to_plan(CreateModelStmt *stmt, PlannedStmt *plan, DestReceiverTrainModel *dest)
{
    PlannedStmt *result = NULL;
    switch (stmt->algorithm) {
        case LOGISTIC_REGRESSION:
        case SVM_CLASSIFICATION:
        case LINEAR_REGRESSION: {
            /* all three are trained by the gradient-descent operator */
            result = add_GradientDescent_to_plan(plan, stmt->algorithm, stmt->hyperparameters, dest);
            break;
        }
        case KMEANS: {
            result = add_kmeans_to_plan(plan, stmt->algorithm, stmt->hyperparameters, dest);
            break;
        }
        case INVALID_ALGORITHM_ML:
        default: {
            /* FIX: string literals are const in C++; do not bind to mutable char* */
            const char *s = "logistic_regression, svm_classification, linear_regression, kmeans";
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Architecture %s is not supported. Supported architectures: %s", stmt->architecture, s)));
        }
    }
    return result;
}
|
||||
|
||||
// Create the query plan with the appropriate machine learning model
|
||||
/*
 * Build the full execution plan for CREATE MODEL: rewrite the embedded
 * SELECT, plan it, then graft the machine-learning training operator on top.
 */
PlannedStmt *plan_create_model(CreateModelStmt *stmt, const char *query_string, ParamListInfo params,
    DestReceiver *dest)
{
    Query *rewritten = setup_for_create_model((Query *)stmt->select_query, query_string, params);
    PlannedStmt *planned = pg_plan_query(rewritten, 0, params);
    return add_create_model_to_plan(stmt, planned, (DestReceiverTrainModel *)dest);
}
|
||||
|
||||
// Prepare the DestReceiver for training
|
||||
/*
 * Prepare the training destination receiver: allocate the model structure
 * matching the algorithm and seed its bookkeeping fields.
 */
void configure_dest_receiver_train_model(DestReceiverTrainModel *dest, AlgorithmML algorithm, const char *model_name,
    const char *sql)
{
    size_t model_bytes = 0;

    /* each algorithm family serializes into its own model structure */
    switch (algorithm) {
        case LOGISTIC_REGRESSION:
        case LINEAR_REGRESSION:
        case SVM_CLASSIFICATION: {
            model_bytes = sizeof(ModelGradientDescent);
            break;
        }
        case KMEANS: {
            model_bytes = sizeof(ModelKMeans);
            break;
        }
        default: {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Unsupported model type in model warehouse %d", algorithm)));
        }
    }

    dest->model = (Model *)palloc0(model_bytes);
    dest->model->algorithm = algorithm;
    dest->model->model_name = pstrdup(model_name);
    dest->model->sql = sql;
    dest->targetlist = nullptr;
}
|
||||
|
||||
|
||||
// /*
|
||||
// * ExecCreateTableAs -- execute a CREATE TABLE AS command
|
||||
// */
|
||||
/*
 * exec_create_model -- execute a CREATE MODEL command
 *
 * Plans the training query with the ML operator grafted on top, then drives
 * it through the normal executor lifecycle (start/run/finish/end) with output
 * redirected to a train-model destination receiver, which stores the trained
 * model. Structure is adapted from ExecCreateTableAs.
 *
 * completionTag, if non-NULL, receives "MODEL CREATED. PROCESSED <n>".
 */
void exec_create_model(CreateModelStmt *stmt, const char *queryString, ParamListInfo params, char *completionTag)
{
#ifdef ENABLE_MULTIPLE_NODES
    ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
        errmsg("No support for distributed scenarios yet.")));
#endif
    DestReceiverTrainModel *dest = NULL;

    PlannedStmt *plan = NULL;
    QueryDesc *queryDesc = NULL;
    ScanDirection dir;

    /*
     * Create the tuple receiver object and give it the model metadata
     * (algorithm, model name, original SQL) it will need
     */
    dest = (DestReceiverTrainModel *)CreateDestReceiver(DestTrainModel);
    configure_dest_receiver_train_model(dest, (AlgorithmML)stmt->algorithm, stmt->model, queryString);

    plan = plan_create_model(stmt, queryString, params, (DestReceiver *)dest);

    /*
     * Use a snapshot with an updated command ID to ensure this query sees
     * results of any previously executed queries. (This could only matter if
     * the planner executed an allegedly-stable function that changed the
     * database contents, but let's do it anyway to be parallel to the EXPLAIN
     * code path.)
     */
    PushCopiedSnapshot(GetActiveSnapshot());
    UpdateActiveSnapshotCommandId();

    /* Create a QueryDesc, redirecting output to our tuple receiver */
    queryDesc = CreateQueryDesc(plan, queryString, GetActiveSnapshot(), InvalidSnapshot, &dest->dest, params, 0);

#ifdef ENABLE_MULTIPLE_NODES
    if (ENABLE_WORKLOAD_CONTROL && (IS_PGXC_COORDINATOR)) {
#else
    if (ENABLE_WORKLOAD_CONTROL) {
#endif
        /* Check if need track resource */
        u_sess->exec_cxt.need_track_resource = WLMNeedTrackResource(queryDesc);
    }

    /* call ExecutorStart to prepare the plan for execution */
    ExecutorStart(queryDesc, 0);

    /* workload client manager */
    if (ENABLE_WORKLOAD_CONTROL) {
        WLMInitQueryPlan(queryDesc);
        dywlm_client_manager(queryDesc);
    }

    dir = ForwardScanDirection;

    /* run the plan to completion; the receiver stores the model as a side effect */
    ExecutorRun(queryDesc, dir, 0L);

    /* save the rowcount if we're given a completionTag to fill */
    if (completionTag != NULL) {
        errno_t rc;
        rc = snprintf_s(completionTag, COMPLETION_TAG_BUFSIZE, COMPLETION_TAG_BUFSIZE - 1,
            "MODEL CREATED. PROCESSED %lu", queryDesc->estate->es_processed);
        securec_check_ss(rc, "\0", "\0");
    }

    /* and clean up */
    ExecutorFinish(queryDesc);
    ExecutorEnd(queryDesc);

    FreeQueryDesc(queryDesc);

    PopActiveSnapshot();
}
|
||||
|
||||
/*
 * store_gd_expr_in_model
 * Copies one gradient-descent result attribute (identified by its
 * GradientDescentExprField tag) from the executor's output tuple into the
 * in-memory model structures.
 *
 * dt/type  - attribute value and its type OID (type is Assert-checked only)
 * col      - attribute position, used to look up attbyval/attlen for datumCopy
 * model    - generic model fields; model_gd - GD-specific fields
 * Raises ERROR for fields with no serialization mapping.
 */
static void store_gd_expr_in_model(Datum dt, Oid type, int col, GradientDescentExprField field, Model *model,
    ModelGradientDescent *model_gd, TupleDesc tupdesc)
{
    switch (field) {
        case GD_EXPR_ALGORITHM:
            Assert(type == INT4OID);
            model_gd->model.algorithm = (AlgorithmML)DatumGetInt32(dt);
            break;
        case GD_EXPR_OPTIMIZER:
            break; // Ignore field
        case GD_EXPR_RESULT_TYPE:
            Assert(type == OIDOID);
            model->return_type = DatumGetUInt32(dt);
            break;
        case GD_EXPR_NUM_ITERATIONS:
            Assert(type == INT4OID);
            model->num_actual_iterations = DatumGetInt32(dt);
            break;
        case GD_EXPR_EXEC_TIME_MSECS:
            Assert(type == FLOAT4OID);
            /* executor reports milliseconds; the model keeps seconds */
            model->exec_time_secs = DatumGetFloat4(dt) / 1000.0;
            break;
        case GD_EXPR_PROCESSED_TUPLES:
            Assert(type == INT4OID);
            model->processed_tuples = DatumGetInt32(dt);
            break;
        case GD_EXPR_DISCARDED_TUPLES:
            Assert(type == INT4OID);
            model->discarded_tuples = DatumGetInt32(dt);
            break;
        case GD_EXPR_WEIGHTS:
            Assert(type == FLOAT4ARRAYOID);
            /* deep-copy: the slot's datum does not outlive the scan */
            model_gd->weights = datumCopy(dt, tupdesc->attrs[col]->attbyval, tupdesc->attrs[col]->attlen);
            break;
        case GD_EXPR_CATEGORIES: {
            ArrayType *arr = (ArrayType *)DatumGetPointer(dt);
            model_gd->ncategories = ARR_DIMS(arr)[0];
            model_gd->categories =
                datumCopy(dt, tupdesc->attrs[col]->attbyval, tupdesc->attrs[col]->attlen);
        } break;
        default:
            (void)type;
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("Model warehouse for GradientDescent field %d not implemented", field)));
            break;
    }
}
|
||||
|
||||
/*
 * store_tuple_gd_in_model_warehouse
 * Walks the gradient-descent result tuple column by column, pairing each
 * attribute with its target-list GradientDescentExpr, and fills the model:
 * score columns (GD_EXPR_SCORE flag set) become TrainingScore entries, all
 * other fields are dispatched to store_gd_expr_in_model. NULL attributes
 * are skipped.
 */
static void store_tuple_gd_in_model_warehouse(TupleTableSlot *slot, DestReceiverTrainModel *dest)
{
    Assert(dest->targetlist != nullptr);

    TupleDesc tupdesc = slot->tts_tupleDescriptor;
    Model *model = dest->model;
    model->pre_time_secs = 0.0;

    ModelGradientDescent *model_gd = (ModelGradientDescent *)model;
    model_gd->ncategories = 0;
    model_gd->model.algorithm = INVALID_ALGORITHM_ML; // undefined

    /* targetlist entries and tuple columns are in the same order */
    int col = 0;
    ListCell *lc;
    foreach (lc, dest->targetlist) {
        TargetEntry *target = lfirst_node(TargetEntry, lc);
        GradientDescentExpr *expr = (GradientDescentExpr *)target->expr;
        if (!slot->tts_isnull[col]) {
            Datum dt = slot->tts_values[col];
            Oid type = tupdesc->attrs[col]->atttypid;
            if ((expr->field & GD_EXPR_SCORE) != 0) {
                /* score columns: keep (resname, value) pairs on the model */
                Assert(type == FLOAT4OID);
                TrainingScore *score = (TrainingScore *)palloc0(sizeof(TrainingScore));
                score->name = pstrdup(target->resname);
                score->value = DatumGetFloat4(dt);
                model->scores = lappend(model->scores, score);
            } else {
                store_gd_expr_in_model(dt, type, col, expr->field, model, model_gd, tupdesc);
            }
        }
        col++;
    }
}
|
||||
|
||||
static void store_kmeans_data_in_model(uint32_t natts, TupleDesc tupdesc, Datum *values, bool *nulls,
|
||||
Model *model, ModelKMeans *model_kmeans)
|
||||
{
|
||||
ArrayType *centroid_ids = nullptr;
|
||||
ArrayType *centroid_coordinates = nullptr;
|
||||
ArrayType *objective_functions = nullptr;
|
||||
ArrayType *avg_distances = nullptr;
|
||||
ArrayType *min_distances = nullptr;
|
||||
ArrayType *max_distances = nullptr;
|
||||
ArrayType *std_dev_distances = nullptr;
|
||||
ArrayType *cluster_sizes = nullptr;
|
||||
/* these are the inner-facing arrays */
|
||||
int32_t *centroid_ids_data = nullptr;
|
||||
double *centroid_coordiates_data = nullptr;
|
||||
double *objective_functions_data = nullptr;
|
||||
double *avg_distances_data = nullptr;
|
||||
double *min_distances_data = nullptr;
|
||||
double *max_distances_data = nullptr;
|
||||
double *std_dev_distances_data = nullptr;
|
||||
int64_t *cluster_sizes_data = nullptr;
|
||||
Oid oid = 0;
|
||||
Datum attr = 0;
|
||||
|
||||
/*
|
||||
* for tuple at a time we only use one centroid at a time
|
||||
*/
|
||||
for (uint32_t a = 0; a < natts; ++a) {
|
||||
if (unlikely(nulls[a])) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
|
||||
errmsg("Encountered null attribute %u when serializing k-means model when it should not", a)));
|
||||
}
|
||||
|
||||
oid = tupdesc->attrs[a]->atttypid;
|
||||
attr = values[a];
|
||||
|
||||
/*
|
||||
* this switch has to match exactly the schema of the row we return (see nodeKMeans.cpp)
|
||||
* there is a single row (quite big in general) thus the switch executes only once
|
||||
*/
|
||||
switch (a) {
|
||||
case 0:
|
||||
// centroids ids of type INT4ARRAYOID
|
||||
Assert(oid == INT4ARRAYOID);
|
||||
centroid_ids = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
centroid_ids_data = reinterpret_cast<int32_t *>(ARR_DATA_PTR(centroid_ids));
|
||||
break;
|
||||
case 1:
|
||||
// centroids coordinates of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
/*
|
||||
* in tuple at a time, this memory reference is valid until we store the centroid
|
||||
* in the model warehouse
|
||||
*/
|
||||
centroid_coordinates = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
centroid_coordiates_data = reinterpret_cast<double *>(ARR_DATA_PTR(centroid_coordinates));
|
||||
break;
|
||||
case 2:
|
||||
// value of the objective functions (per cluster) of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
objective_functions = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
objective_functions_data = reinterpret_cast<double *>(ARR_DATA_PTR(objective_functions));
|
||||
break;
|
||||
case 3:
|
||||
// avg distance of the clusters of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
avg_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
avg_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(avg_distances));
|
||||
break;
|
||||
case 4:
|
||||
// min distance of the clusters of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
min_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
min_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(min_distances));
|
||||
break;
|
||||
case 5:
|
||||
// max distance of the clusters of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
max_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
max_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(max_distances));
|
||||
break;
|
||||
case 6:
|
||||
// standard deviation of clusters of type FLOAT8ARRAYOID
|
||||
Assert(oid == FLOAT8ARRAYOID);
|
||||
std_dev_distances = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
std_dev_distances_data = reinterpret_cast<double *>(ARR_DATA_PTR(std_dev_distances));
|
||||
break;
|
||||
case 7:
|
||||
// cluster sizes of type INT8ARRAYOID
|
||||
Assert(oid == INT8ARRAYOID);
|
||||
cluster_sizes = reinterpret_cast<ArrayType *>(DatumGetPointer(attr));
|
||||
cluster_sizes_data = reinterpret_cast<int64_t *>(ARR_DATA_PTR(cluster_sizes));
|
||||
break;
|
||||
case 8:
|
||||
// num good points of type INT8OID
|
||||
Assert(oid == INT8OID);
|
||||
model->processed_tuples = DatumGetInt64(attr);
|
||||
break;
|
||||
case 9:
|
||||
// num bad point of type INT8OID
|
||||
Assert(oid == INT8OID);
|
||||
model->discarded_tuples = DatumGetInt64(attr);
|
||||
break;
|
||||
case 10:
|
||||
// seedings time (secs) of type FLOAT8OID
|
||||
Assert(oid == FLOAT8OID);
|
||||
model->pre_time_secs = DatumGetFloat8(attr);
|
||||
break;
|
||||
case 11:
|
||||
// execution time (secs) of type FLOAT8OID
|
||||
Assert(oid == FLOAT8OID);
|
||||
model->exec_time_secs = DatumGetFloat8(attr);
|
||||
break;
|
||||
case 12:
|
||||
// actual number of iterations INT4OID
|
||||
Assert(oid == INT4OID);
|
||||
model->num_actual_iterations = DatumGetInt32(attr);
|
||||
break;
|
||||
case 13:
|
||||
// actual number of centroids INT4OID
|
||||
Assert(oid == INT4OID);
|
||||
model_kmeans->actual_num_centroids = DatumGetInt32(attr);
|
||||
break;
|
||||
case 14:
|
||||
// seed used for computations
|
||||
Assert(oid == INT8OID);
|
||||
model_kmeans->seed = DatumGetInt64(attr);
|
||||
break;
|
||||
default:
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Unknown attribute %u when serializing k-means model", a)));
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t const actual_num_centroids = model_kmeans->actual_num_centroids;
|
||||
uint32_t const dimension = model_kmeans->dimension;
|
||||
uint32_t centroid_coordinates_offset = 0;
|
||||
WHCentroid *current_centroid = nullptr;
|
||||
|
||||
/*
|
||||
* at this point we have extracted all the attributes and the memory representation
|
||||
* of the model can be constructed so that it can be stored in the model warehouse
|
||||
*/
|
||||
model_kmeans->centroids = reinterpret_cast<WHCentroid *>(palloc0(sizeof(WHCentroid) * actual_num_centroids));
|
||||
|
||||
/*
|
||||
* we fill in the information of every centroid
|
||||
*/
|
||||
for (uint32_t current_centroid_idx = 0; current_centroid_idx < actual_num_centroids; ++current_centroid_idx) {
|
||||
current_centroid = model_kmeans->centroids + current_centroid_idx;
|
||||
current_centroid->id = centroid_ids_data[current_centroid_idx];
|
||||
current_centroid->objective_function = objective_functions_data[current_centroid_idx];
|
||||
current_centroid->avg_distance_to_centroid = avg_distances_data[current_centroid_idx];
|
||||
current_centroid->min_distance_to_centroid = min_distances_data[current_centroid_idx];
|
||||
current_centroid->max_distance_to_centroid = max_distances_data[current_centroid_idx];
|
||||
current_centroid->std_dev_distance_to_centroid = std_dev_distances_data[current_centroid_idx];
|
||||
current_centroid->cluster_size = cluster_sizes_data[current_centroid_idx];
|
||||
current_centroid->coordinates = centroid_coordiates_data + centroid_coordinates_offset;
|
||||
centroid_coordinates_offset += dimension;
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * Extract a finished k-means training result from the executor's output slot
 * and copy it into the in-memory model (via store_kmeans_data_in_model).
 * The slot must hold a virtual tuple with exactly NUM_ATTR_OUTPUT attributes.
 */
static void store_tuple_kmeans_in_model_warehouse(TupleTableSlot *slot, DestReceiverTrainModel *dest)
{
    /*
     * sanity checks
     */
    Assert(slot != NULL);
    Assert(!slot->tts_isempty);
    Assert(slot->tts_nvalid == NUM_ATTR_OUTPUT);
    Assert(slot->tts_tupleDescriptor != NULL);
    Assert(!TTS_HAS_PHYSICAL_TUPLE(slot));

    TupleDesc tupdesc = slot->tts_tupleDescriptor;
    auto model_kmeans = reinterpret_cast<ModelKMeans *>(dest->model);
    Model *model = &model_kmeans->model;

    /* release-build counterpart of the Assert above: nothing to extract */
    if (unlikely(slot->tts_isempty))
        return;

    uint32_t const natts = slot->tts_nvalid;
    /*
     * the slot contains a virtual tuple and thus we can access its attributes directly
     */
    Datum *values = slot->tts_values;
    bool *nulls = slot->tts_isnull;

    /*
     * both arrays are dereferenced downstream, so a missing one is fatal;
     * the previous check (!values && !nulls) only fired when BOTH were
     * missing and would have let a half-initialized slot crash later
     */
    if (unlikely(!values || !nulls)) {
        ereport(ERROR,
            (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE), errmsg("Empty arrays values and nulls")));
    }

    store_kmeans_data_in_model(natts, tupdesc, values, nulls, model, model_kmeans);
}
|
||||
|
||||
/*
 * DestReceiver receiveSlot callback: route the result tuple of a finished
 * training run to the algorithm-specific extraction routine, then persist
 * the populated model in the model warehouse.
 */
static void store_tuple_in_model_warehouse(TupleTableSlot *slot, DestReceiver *self)
{
    DestReceiverTrainModel *dest = (DestReceiverTrainModel *)self;
    Model *model = dest->model;

    switch (model->algorithm) {
        /* all gradient-descent based algorithms share one output layout */
        case LOGISTIC_REGRESSION:
        case SVM_CLASSIFICATION:
        case LINEAR_REGRESSION: {
            store_tuple_gd_in_model_warehouse(slot, dest);
            break;
        }
        case KMEANS: {
            store_tuple_kmeans_in_model_warehouse(slot, dest);
            break;
        }
        case INVALID_ALGORITHM_ML:
        default: {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Unsupported model type %d", static_cast<int>(model->algorithm))));
            break;
        }
    }

    /* hand the fully populated in-memory model over to the warehouse */
    store_model(model);
}
|
||||
|
||||
|
||||
/*
 * DestReceiver rStartup callback: model training needs no per-query
 * initialization, so this is intentionally a no-op.
 */
static void do_nothing_startup(DestReceiver *self, int operation, TupleDesc typehyperp)
{
    /* do nothing */
}
|
||||
|
||||
|
||||
/*
 * DestReceiver teardown callback: no resources are held by the receiver
 * beyond palloc'd memory, so nothing needs to be released here.
 */
static void do_nothing_cleanup(DestReceiver *self)
{
    /* this is used for both shutdown and destroy methods */
}
|
||||
|
||||
/*
 * Build a DestReceiver that captures training output tuples and stores them
 * as a model in the model warehouse (mydest = DestTrainModel).
 * The caller is expected to fill in the model-specific fields of the
 * returned DestReceiverTrainModel before execution starts.
 */
DestReceiver *CreateTrainModelDestReceiver()
{
    /* palloc0 leaves every model-related field zero-initialized */
    DestReceiverTrainModel *receiver = (DestReceiverTrainModel *)palloc0(sizeof(DestReceiverTrainModel));

    /* wire up the callbacks of the DestReceiver embedded at the head of the struct */
    receiver->dest.rStartup = do_nothing_startup;
    receiver->dest.receiveSlot = store_tuple_in_model_warehouse;
    receiver->dest.rShutdown = do_nothing_cleanup;
    receiver->dest.rDestroy = do_nothing_cleanup;
    receiver->dest.mydest = DestTrainModel;

    return &receiver->dest;
}
|
|
@ -0,0 +1,153 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/commands/predict_by.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/predict_by.h"
|
||||
|
||||
#include "catalog/pg_type.h"
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "nodes/parsenodes_common.h"
|
||||
#include "parser/parse_expr.h"
|
||||
#include "utils/array.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
#define DEBUG_MODEL_RETURN_TYPE INT4OID // Set manually the return type of model, until available from catalog
|
||||
|
||||
/*
|
||||
* functions relevant to k-means and defined in kmeans.cpp
|
||||
*/
|
||||
ModelPredictor kmeans_predict_prepare(Model const * model);
|
||||
Datum kmeans_predict(ModelPredictor model, Datum *data, bool *nulls, Oid *types, int nargs);
|
||||
|
||||
/*
 * Per-call prediction state cached in fn_extra by db4ai_predict_by:
 * the algorithm's prediction API plus the prepared in-memory predictor
 * for one model.
 */
struct PredictionByData {
    PredictorInterface *api;        /* prepare/predict function pointers for the algorithm */
    ModelPredictor model_predictor; /* model pre-processed for efficient repeated prediction */
};
|
||||
|
||||
/*
 * Build the prediction state for one model: select the algorithm's API
 * implementation and prepare the in-memory predictor.
 * Raises ERROR for algorithms without prediction support.
 */
static PredictionByData *initialize_predict_by_data(Model *model)
{
    PredictionByData *result = (PredictionByData *)palloc0(sizeof(PredictionByData));
    PredictorInterface *api = NULL;

    // Pick the API handlers for this algorithm family
    switch (model->algorithm) {
        case LOGISTIC_REGRESSION:
        case SVM_CLASSIFICATION:
        case LINEAR_REGRESSION:
            /* the gradient-descent family shares a single predictor implementation */
            api = (PredictorInterface *)palloc0(sizeof(PredictorInterface));
            api->predict = gd_predict;
            api->prepare = gd_predict_prepare;
            break;
        case KMEANS:
            api = (PredictorInterface *)palloc0(sizeof(PredictorInterface));
            api->predict = kmeans_predict;
            api->prepare = kmeans_predict_prepare;
            break;
        case INVALID_ALGORITHM_ML:
        default:
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Model type %d is not supported for prediction", (int)model->algorithm)));
            break;
    }

    result->api = api;
    // Prepare the in memory version of the model for efficient prediction
    result->model_predictor = api->prepare(model);

    return result;
}
|
||||
|
||||
/*
 * Generic PREDICT BY implementation.
 * Argument 0 is the model name (text); the remaining arguments are the
 * feature values handed to the model predictor.
 * The prepared predictor is cached in fn_extra, so the model is loaded and
 * prepared only once per query and reused for every subsequent row.
 */
Datum db4ai_predict_by(PG_FUNCTION_ARGS)
{
    // First argument is the model, the following ones are the inputs to the model predictor
    if (PG_NARGS() < 2) {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("INVALID Number of parameters %d", PG_NARGS())));
    }

    /* slices of the fcinfo arrays starting at the first feature argument */
    int var_args_size = PG_NARGS() - 1;
    Datum *var_args = &fcinfo->arg[1]; // We skip the model name
    bool *nulls = &fcinfo->argnull[1];
    Oid *types = &fcinfo->argTypes[1];

    PredictionByData *prediction_by_data;
    if (fcinfo->flinfo->fn_extra != NULL) {
        /* fast path: predictor already prepared by an earlier call */
        prediction_by_data = (PredictionByData *)fcinfo->flinfo->fn_extra;
    } else {
        // Initialize the model, and save the object as state of the function.
        // The memory allocation is done in the function context. So, the model
        // will be deallocated automatically when the function finishes
        MemoryContext oldContext = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
        text *model_name_text = PG_GETARG_TEXT_P(0);
        char *model_name = text_to_cstring(model_name_text);

        Model *model = get_model(model_name, false);
        prediction_by_data = initialize_predict_by_data(model);
        fcinfo->flinfo->fn_extra = prediction_by_data;
        pfree(model_name);

        MemoryContextSwitchTo(oldContext);
    }

    /* delegate the actual prediction to the algorithm-specific handler */
    Datum result =
        prediction_by_data->api->predict(prediction_by_data->model_predictor, var_args, nulls, types, var_args_size);

    PG_RETURN_DATUM(result);
}
|
||||
|
||||
// These functions are only a wrapper for the generic function. We need this wrapper
// to be compliant with openGauss type system that specifies the return type.
// One SQL-level entry point exists per supported return type; each simply
// forwards fcinfo to db4ai_predict_by, which returns the predictor's Datum.
Datum db4ai_predict_by_bool(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_int32(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_int64(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_float4(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_float8(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_numeric(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}

Datum db4ai_predict_by_text(PG_FUNCTION_ARGS)
{
    return db4ai_predict_by(fcinfo);
}
|
|
@ -0,0 +1,32 @@
|
|||
# executor.cmake
#
# Build rules for the db4ai executor objects: distance kernels, robust
# floating-point ops and hyperparameter validation, plus the gd and kmeans
# sub-modules.

# sources compiled into the gausskernel_db4ai_executor object target
set(TGT_executor_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/distance_functions.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/fp_ops.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/hyperparameter_validation.cpp
)
# include search paths (server-wide layout shared by kernel components)
set(TGT_executor_INC
    ${PROJECT_OPENGS_DIR}/contrib/log_fdw
    ${PROJECT_TRUNK_DIR}/distribute/bin/gds
    ${PROJECT_SRC_DIR}/include/libcomm
    ${PROJECT_SRC_DIR}/include
    ${PROJECT_SRC_DIR}/lib/gstrace
    ${LZ4_INCLUDE_PATH}
    ${LIBCGROUP_INCLUDE_PATH}
    ${LIBORC_INCLUDE_PATH}
    ${EVENT_INCLUDE_PATH}
    ${PROTOBUF_INCLUDE_PATH}
    ${ZLIB_INCLUDE_PATH}
)
set(executor_DEF_OPTIONS ${MACRO_OPTIONS})
set(executor_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
# these objects are linked into the server executable: build -fPIE, not -fPIC
list(REMOVE_ITEM executor_COMPILE_OPTIONS -fPIC)
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
    # -mavx enables the vectorized distance kernels on x86_64
    set(executor_COMPILE_OPTIONS ${executor_COMPILE_OPTIONS} -std=c++14 -fPIE -mavx)
else()
    set(executor_COMPILE_OPTIONS ${executor_COMPILE_OPTIONS} -std=c++14 -fPIE)
endif()
set(executor_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor TGT_executor_SRC TGT_executor_INC "${executor_DEF_OPTIONS}" "${executor_COMPILE_OPTIONS}" "${executor_LINK_OPTIONS}")

add_subdirectory(gd)
add_subdirectory(kmeans)
|
|
@ -0,0 +1,33 @@
|
|||
#---------------------------------------------------------------------------------------
#
# Makefile for the db4ai executor: distance kernels, robust floating-point
# ops, hyperparameter validation, and the gd/kmeans sub-directories.
#
# IDENTIFICATION
#        src/gausskernel/dbmind/db4ai/executor/Makefile
#
# ---------------------------------------------------------------------------------------


subdir = src/gausskernel/dbmind/db4ai/executor
top_builddir = ../../../../..

include $(top_builddir)/src/Makefile.global

# enable AVX so the vectorized distance kernels are compiled on x86_64 hosts
PLATFORM_ARCH = $(shell uname -p)
ifeq ($(PLATFORM_ARCH),x86_64)
override CPPFLAGS += -mavx
endif

# pull in generated dependency files, except when cleaning or when building
# under the hutaf_llt toolchain (detected through the g++ on PATH)
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif

SUBDIRS = gd kmeans
OBJS = fp_ops.o distance_functions.o hyperparameter_validation.o

include $(top_srcdir)/src/gausskernel/common.mk
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,505 @@
|
|||
/* *
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
distance_functions.cpp
|
||||
Current set of distance functions that can be used (for k-means for example)
|
||||
|
||||
IDENTIFICATION
|
||||
src/gausskernel/dbmind/db4ai/executor/distance_functions.cpp
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
* */
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "db4ai/distance_functions.h"
|
||||
#include "db4ai/fp_ops.h"
|
||||
#include "db4ai/db4ai_cpu.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#if defined(__x86_64__) && defined(__SSE3__)
|
||||
#include <pmmintrin.h>
|
||||
#elif defined(__aarch64__) && defined(__ARM_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
|
||||
/*
 * L1 distance (Manhattan): sum of |q[d] - p[d]|.
 * We sum using cascaded (error-compensated) summation built on the
 * twoDiff/twoSum primitives from fp_ops to reduce rounding error.
 * This version is unvectorized and is used in case vectorized instructions
 * are not available or for the case that the dimension is not a multiple
 * of the width of the registers.
 * Precondition: dimension >= 1 (p[0]/q[0] are read unconditionally;
 * callers validate the dimension).
 */
static force_inline double l1_non_vectorized(double const * p, double const * q, uint32_t const dimension)
{
    double term = 0.;
    double term_correction = 0.;
    double distance = 0.;
    double distance_correction = 0.;

    /* first coordinate handled outside the loop so 'distance' starts exact */
    twoDiff(q[0], p[0], &term, &term_correction);
    term += term_correction;
    // absolute value of the difference (hopefully done by clearing the sign bit)
    distance = std::abs(term);

    for (uint32_t d = 1; d < dimension; ++d) {
        twoDiff(q[d], p[d], &term, &term_correction);
        term += term_correction;
        term = std::abs(term);
        /* compensated add: accumulate the rounding error of each step separately */
        twoSum(distance, term, &distance, &term_correction);
        distance_correction += term_correction;
    }

    /* fold the accumulated correction back into the result */
    return distance + distance_correction;
}
|
||||
|
||||
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
|
||||
/*
 * L1 distance (Manhattan)
 * This version is vectorized using SSE or NEON and is used in case only 128-bit
 * vectorized instructions are available.
 * Strategy: the (dimension mod 2) leading coordinates are handled by the
 * scalar kernel; the remaining even-sized range is consumed two doubles at a
 * time; the two partial results are combined with a compensated (twoSum) add.
 */
static double l1_128(double const * p, double const * q, uint32_t const dimension)
{
    if (unlikely(dimension == 0))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("L1 distance (128-bit): dimension must be larger than 0")));

    double distance[2] = {0.};
    /* the result of dimension modulo 2 */
    uint32_t const dimension_remainder = dimension & 0x1U;
    /* number of doubles per 128-bit register */
    uint32_t const offset = 2U;
    double distance_first_terms = 0.;
    double local_distance_correction = 0.;
    double global_distance_correction = 0.;

    /*
     * this will compute the very first terms of the distance
     * (the ones that cannot be fully computed using simd registers, and
     * thus have to be done with scalar computations)
     */
    if (dimension_remainder > 0) {
        /*
         * this is gonna compute the 1-dimensional distance (the very first
         * term in the whole distance computation)
         */
        distance_first_terms = l1_non_vectorized(p, q, dimension_remainder);
    }

    /*
     * if the dimension is < 2 then the term above is the whole distance
     * otherwise we have at least one simd register we can fill
     */
    if (unlikely(dimension < offset))
        return distance_first_terms;

#if defined(__x86_64__)
    __m128d const zero = _mm_setzero_pd();
    /* only the sign bit set in each lane; used with andnot to compute |x| */
    __m128d const sign_mask = _mm_set1_pd(-0.0f);
    __m128d sum = zero;
    __m128d absolute_value;
    __m128d sub;
    __m128d p128, q128;
#else // aarch64
    float64x2_t const zero = vdupq_n_f64(0);
    float64x2_t sum = zero;
    float64x2_t absolute_value;
    float64x2_t sub;
    float64x2_t p128, q128;
#endif

    /* the remaining range is even, so the vector loop consumes it exactly */
    Assert(((dimension - dimension_remainder) & 0x1U) == 0U);

    for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
        p128 = _mm_loadu_pd(p + d);
        q128 = _mm_loadu_pd(q + d);
        sub = _mm_sub_pd(p128, q128);
        /* this clears out the sign bit of sub (thus computing its absolute value) */
        absolute_value = _mm_andnot_pd(sign_mask, sub);
        sum = _mm_add_pd(sum, absolute_value);
#else // aarch64
        p128 = vld1q_f64(p + d);
        q128 = vld1q_f64(q + d);
        sub = vsubq_f64(p128, q128);
        /* vabsq_f64 computes the lane-wise absolute value */
        absolute_value = vabsq_f64(sub);
        sum = vaddq_f64(sum, absolute_value);
#endif
    }
    /*
     * in here we end up having a register with two terms that need to be added up to produce
     * the final answer, we first perform an horizontal add to reduce two to one term
     */
#if defined(__x86_64__)
    sum = _mm_hadd_pd(sum, zero);
#else // aarch64
    sum = vpaddq_f64(sum, zero);
#endif

    /*
     * we extract the remaining term [x,0] to produce the final solution
     */
#if defined(__x86_64__)
    _mm_storeu_pd(distance, sum);
#else // aarch64
    vst1q_f64(distance, sum);
#endif

    if (dimension_remainder > 0) {
        /* distance[0] += distance_first_terms, with compensated rounding */
        twoSum(distance[0], distance_first_terms, distance, &local_distance_correction);
        global_distance_correction += local_distance_correction;
    }

    return distance[0] + global_distance_correction;
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
/*
 * Squared Euclidean (default): sum of (q[d] - p[d])^2.
 * We sum using cascaded (error-compensated) summation built on the
 * twoDiff/square/twoSum primitives from fp_ops.
 * This version is unvectorized and is used in case vectorized instructions
 * are not available or for the case that the dimension is not a multiple
 * of the width of the registers.
 * Precondition: dimension >= 1 (p[0]/q[0] are read unconditionally;
 * callers validate the dimension).
 */
static force_inline double l2_squared_non_vectorized(double const * p, double const * q, uint32_t const dimension)
{
    double subtraction = 0.;
    double subtraction_correction = 0.;
    double term = 0.;
    double term_correction = 0.;
    double distance = 0.;
    double distance_correction = 0.;

    /* first coordinate handled outside the loop so 'distance' starts exact */
    twoDiff(q[0], p[0], &subtraction, &subtraction_correction);
    subtraction += subtraction_correction;
    square(subtraction, &term, &term_correction);
    term += term_correction;
    distance = term;

    for (uint32_t d = 1; d < dimension; ++d) {
        twoDiff(q[d], p[d], &subtraction, &subtraction_correction);
        subtraction += subtraction_correction;
        square(subtraction, &term, &term_correction);
        term += term_correction;
        /* compensated add: accumulate the rounding error of each step separately */
        twoSum(distance, term, &distance, &term_correction);
        distance_correction += term_correction;
    }

    /* fold the accumulated correction back into the result */
    return distance + distance_correction;
}
|
||||
|
||||
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
|
||||
/*
 * Squared Euclidean (default)
 * This version is vectorized using SSE or NEON and is used in case only 128-bit
 * vectorized instructions are available.
 * Strategy: the (dimension mod 2) leading coordinates are handled by the
 * scalar kernel; the remaining even-sized range is consumed two doubles at a
 * time; the two partial results are combined with a compensated (twoSum) add.
 */
static double l2_squared_128(double const * p, double const * q, uint32_t const dimension)
{
    if (unlikely(dimension == 0))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("L2 squared distance (128-bit): dimension must be larger than 0")));

    double distance[2] = {0.};
    /* the result of dimension modulo 2 */
    uint32_t const dimension_remainder = dimension & 0x1U;
    /* number of doubles per 128-bit register */
    uint32_t const offset = 2U;
    double distance_first_terms = 0.;
    double local_distance_correction = 0.;
    double global_distance_correction = 0.;

    /*
     * this will compute the very first terms of the distance
     * (the ones that cannot be fully computed using simd registers, and
     * thus have to be done with scalar computations)
     */
    if (dimension_remainder > 0) {
        /*
         * this is gonna compute the 1-dimensional distance (the very first
         * term in the whole distance computation)
         */
        distance_first_terms = l2_squared_non_vectorized(p, q, dimension_remainder);
    }

    /*
     * if the dimension is < 2 then the term above is the whole distance
     * otherwise we have at least one simd register we can fill
     */
    if (unlikely(dimension < offset))
        return distance_first_terms;

#if defined(__x86_64__)
    __m128d const zero = _mm_setzero_pd();
    __m128d sum = zero;
    /* NOTE(review): this local shadows the square() helper from fp_ops
     * inside this function — consider renaming */
    __m128d square;
    __m128d sub;
    __m128d p128, q128;
#else // aarch64
    float64x2_t const zero = vdupq_n_f64(0);
    float64x2_t sum = zero;
    float64x2_t square;
    float64x2_t sub;
    float64x2_t p128, q128;
#endif

    /* the remaining range is even, so the vector loop consumes it exactly */
    Assert(((dimension - dimension_remainder) & 0x1U) == 0U);

    for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
        p128 = _mm_loadu_pd(p + d);
        q128 = _mm_loadu_pd(q + d);
        sub = _mm_sub_pd(p128, q128);
        square = _mm_mul_pd(sub, sub);
        sum = _mm_add_pd(sum, square);
#else // aarch64
        p128 = vld1q_f64(p + d);
        q128 = vld1q_f64(q + d);
        sub = vsubq_f64(p128, q128);
        square = vmulq_f64(sub, sub);
        sum = vaddq_f64(sum, square);
#endif
    }
    /*
     * in here we end up having a register with two terms that need to be added up to produce
     * the final answer, we first perform an horizontal add to reduce two to one term
     */
#if defined(__x86_64__)
    sum = _mm_hadd_pd(sum, zero);
#else // aarch64
    sum = vpaddq_f64(sum, zero);
#endif

    /*
     * we extract the remaining term [x,0] to produce the final solution
     */
#if defined(__x86_64__)
    _mm_storeu_pd(distance, sum);
#else // aarch64
    vst1q_f64(distance, sum);
#endif

    if (dimension_remainder > 0) {
        /* distance[0] += distance_first_terms, with compensated rounding */
        twoSum(distance[0], distance_first_terms, distance, &local_distance_correction);
        global_distance_correction += local_distance_correction;
    }

    return distance[0] + global_distance_correction;
}
|
||||
|
||||
#endif
|
||||
|
||||
/*
|
||||
* L infinity distance (Chebyshev)
|
||||
* This version is unvectorized and is used in case vectorized instructions
|
||||
* are not available or for the the case that the dimension is not a multiple
|
||||
* of the width of the registers
|
||||
*/
|
||||
static force_inline double linf_non_vectorized(double const * p, double const * q, uint32_t const dimension)
|
||||
{
|
||||
double distance = 0.;
|
||||
|
||||
double term = 0.;
|
||||
double term_correction = 0.;
|
||||
|
||||
twoDiff(q[0], p[0], &term, &term_correction);
|
||||
term += term_correction;
|
||||
// absolute value of the difference (hopefully done by clearing the sign bit)
|
||||
distance = std::abs(term);
|
||||
|
||||
for (uint32_t d = 1; d < dimension; ++d) {
|
||||
twoDiff(q[d], p[d], &term, &term_correction);
|
||||
term += term_correction;
|
||||
term = std::abs(term);
|
||||
distance = term > distance ? term : distance;
|
||||
}
|
||||
|
||||
return distance;
|
||||
}
|
||||
|
||||
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
|
||||
/*
 * L infinity distance (Chebyshev)
 * This version is vectorized using SSE or NEON and is used in case only 128-bit
 * vectorized instructions are available.
 * Strategy: the (dimension mod 2) leading coordinates are handled by the
 * scalar kernel; the remaining even-sized range keeps a lane-wise running
 * maximum that is reduced against the scalar result at the end.
 */
static double linf_128(double const * p, double const * q, uint32_t const dimension)
{
    if (unlikely(dimension == 0))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("L infinity distance (128-bit): dimension must be larger than 0")));

    double distance[2] = {0.};
    /* the result of dimension modulo 2 */
    uint32_t const dimension_remainder = dimension & 0x1U;
    /* number of doubles per 128-bit register */
    uint32_t const offset = 2U;
    double distance_first_terms = 0.;

    /*
     * this will compute the very first terms of the distance
     * (the ones that cannot be fully computed using simd registers)
     */
    if (dimension_remainder > 0)
        distance_first_terms = linf_non_vectorized(p, q, dimension_remainder);

    /*
     * if the dimension is < 2 then the term above is the whole distance
     * otherwise we have at least one simd register we can fill
     */
    if (unlikely(dimension < offset))
        return distance_first_terms;

#if defined(__x86_64__)
    __m128d const zero = _mm_setzero_pd();
    /* only the sign bit set in each lane; used with andnot to compute |x| */
    __m128d const sign_mask = _mm_set1_pd(-0.0f);
    __m128d max = zero;
    __m128d absolute_value;
    __m128d sub;
    __m128d p128, q128;
#else // aarch64
    float64x2_t const zero = vdupq_n_f64(0);
    float64x2_t max = zero;
    float64x2_t absolute_value;
    float64x2_t sub;
    float64x2_t p128, q128;
#endif

    /* the remaining range is even, so the vector loop consumes it exactly */
    Assert(((dimension - dimension_remainder) & 0x1U) == 0U);

    for (uint32_t d = dimension_remainder; d < dimension; d += offset) {
#if defined(__x86_64__)
        p128 = _mm_loadu_pd(p + d);
        q128 = _mm_loadu_pd(q + d);
        sub = _mm_sub_pd(p128, q128);
        /* this clears out the sign bit of sub (thus computing its absolute value) */
        absolute_value = _mm_andnot_pd(sign_mask, sub);
        max = _mm_max_pd(max, absolute_value);
#else // aarch64
        p128 = vld1q_f64(p + d);
        q128 = vld1q_f64(q + d);
        sub = vsubq_f64(p128, q128);
        /* vabsq_f64 computes the lane-wise absolute value */
        absolute_value = vabsq_f64(sub);
        max = vmaxq_f64(max, absolute_value);
#endif
    }
    /*
     * in here we end up having a register with two terms, from which we extract the max
     * to produce the final answer
     */
#if defined(__x86_64__)
    _mm_storeu_pd(distance, max);
#else // aarch64
    vst1q_f64(distance, max);
#endif

    /* reduce both lanes against the scalar head result */
    double result = distance_first_terms;
    for (uint32_t m = 0; m < offset; ++m)
        result = distance[m] > result ? distance[m] : result;

    return result;
}
|
||||
|
||||
#endif
|
||||
|
||||
/*
|
||||
* L1 distance (Manhattan)
|
||||
* This is the main function. It will be automatically vectorized
|
||||
* if possible
|
||||
*/
|
||||
double l1(double const * p, double const * q, uint32_t const dimension)
|
||||
{
|
||||
if (unlikely(dimension == 0))
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("L1 distance: dimension must be larger than 0")));
|
||||
/*
|
||||
* depending on the feature of the underlying processor we vectorized one way
|
||||
* or another. in the worst case we do not vectorized at all
|
||||
*/
|
||||
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
|
||||
return l1_128(p, q, dimension);
|
||||
#else
|
||||
return l1_non_vectorized(p, q, dimension);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Squared Euclidean (default)
|
||||
* This is the main function. It will be automatically vectorized
|
||||
* if possible
|
||||
*/
|
||||
double l2_squared(double const * p, double const * q, uint32_t const dimension)
|
||||
{
|
||||
if (unlikely(dimension == 0))
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("L2 squared distance: dimension must be larger than 0")));
|
||||
/*
|
||||
* depending on the feature of the underlying processor we vectorized one way
|
||||
* or another. in the worst case we do not vectorized at all
|
||||
*/
|
||||
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
|
||||
return l2_squared_128(p, q, dimension);
|
||||
#else
|
||||
return l2_squared_non_vectorized(p, q, dimension);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* L2 distance (Euclidean)
|
||||
* This is the main function. It will be automatically vectorized
|
||||
* if possible
|
||||
*/
|
||||
double l2(double const * p, double const * q, uint32_t const dimension)
|
||||
{
|
||||
if (unlikely(dimension == 0))
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("L2 distance: dimension must be larger than 0")));
|
||||
|
||||
/*
|
||||
* this one is vectorized automatically (or not)
|
||||
*/
|
||||
double const l2_sq = l2_squared(p, q, dimension);
|
||||
|
||||
// we can replace this with exact computation via mpfr (more costly, but the best alternative
|
||||
// for fixed precision)
|
||||
return std::sqrt(l2_sq);
|
||||
}
|
||||
|
||||
/*
 * L infinity distance (Chebyshev)
 * This is the main function. It will be automatically vectorized
 * if possible
 */
double linf(double const * p, double const * q, uint32_t const dimension)
{
    // zero-dimensional vectors have no meaningful distance
    if (unlikely(dimension == 0))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("L infinity distance: dimension must be larger than 0")));
    /*
     * depending on the feature of the underlying processor we vectorized one way
     * or another. in the worst case we do not vectorized at all
     */
#if (defined(__x86_64__) && defined(__SSE3__)) || (defined(__aarch64__) && defined(__ARM_NEON))
    // 128-bit SIMD path (SSE3 on x86-64, NEON on aarch64)
    return linf_128(p, q, dimension);
#else
    // portable scalar fallback
    return linf_non_vectorized(p, q, dimension);
#endif
}
|
|
@ -0,0 +1,256 @@
|
|||
/* *
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
fp_ops.cpp
|
||||
Robust floating point operations
|
||||
|
||||
IDENTIFICATION
|
||||
src/gausskernel/dbmind/db4ai/executor/fp_ops.cpp
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
* */
|
||||
|
||||
#include "nodes/execnodes.h"
|
||||
|
||||
#include "db4ai/fp_ops.h"
|
||||
#include "db4ai/db4ai_cpu.h"
|
||||
|
||||
#define WITH_ROBUST_OPS
|
||||
|
||||
/*
 * Knuth's high precision sum a + b (Shewchuk's version).
 * Computes *sum = fl(a + b) and the exact rounding error *e, so that
 * a + b == *sum + *e holds exactly when WITH_ROBUST_OPS is enabled;
 * otherwise *e is simply zero.
 */
void twoSum(double const a, double const b, double *sum, double *e)
{
    double const s = a + b;
    *sum = s;
#ifdef WITH_ROBUST_OPS
    // recover the rounding error of the addition (Shewchuk's TWO-SUM)
    double const b_virtual = s - a;
    double const a_virtual = s - b_virtual;
    double const b_err = b - b_virtual;
    double const a_err = a - a_virtual;
    *e = a_err + b_err;
#else
    *e = 0.;
#endif
}
|
||||
|
||||
/*
 * The equivalent subtraction a - b (Shewchuk's version).
 * Computes *sub = fl(a - b) and the exact rounding error *e, so that
 * a - b == *sub + *e holds exactly when WITH_ROBUST_OPS is enabled;
 * otherwise *e is simply zero.
 */
void twoDiff(double const a, double const b, double *sub, double *e)
{
    double const d = a - b;
    *sub = d;
#ifdef WITH_ROBUST_OPS
    // recover the rounding error of the subtraction (Shewchuk's TWO-DIFF)
    double const b_virtual = a - d;
    double const a_virtual = d + b_virtual;
    double const b_err = b_virtual - b;
    double const a_err = a - a_virtual;
    *e = a_err + b_err;
#else
    *e = 0.;
#endif
}
|
||||
|
||||
/*
 * Veltkamp splitting of a double p into p_hi + p_low, where each half fits
 * into at most 26 significand bits so products of halves are exact doubles.
 * Used by twoMult()/square() to recover exact rounding errors.
 *
 * Fix: the splitting constant times p must be kept in floating point.
 * The previous code stored it in a uint32_t, which truncated the product
 * to an integer (and overflowed for |p| larger than ~32), breaking the
 * split entirely.
 */
static force_inline void veltkamp_split(double p, double *p_hi, double *p_low)
{
    uint32_t const shift = 27U; // ceil(53 / 2)
    // (2^27 + 1) * p, kept as a double — this is the core of the split
    double const c = static_cast<double>((1U << shift) + 1U) * p;
    double const p_big = c - p;
    *p_hi = c - p_big;
    *p_low = p - *p_hi;
}
|
||||
|
||||
/*
|
||||
* High precision product a * b (Shewchuk's version of Dekker-Veltkamp)
|
||||
*/
|
||||
void twoMult(double const a, double const b, double *mult, double *e)
|
||||
{
|
||||
*mult = a * b;
|
||||
#ifdef WITH_ROBUST_OPS
|
||||
double a_hi, a_low;
|
||||
double b_hi, b_low;
|
||||
double e_1, e_2, e_3;
|
||||
veltkamp_split(a, &a_hi, &a_low);
|
||||
veltkamp_split(b, &b_hi, &b_low);
|
||||
e_1 = *mult - (a_hi * b_hi);
|
||||
e_2 = e_1 - (a_low * b_hi);
|
||||
e_3 = e_2 - (a_hi * b_low);
|
||||
*e = (a_low * b_low) - e_3;
|
||||
#else
|
||||
*e = 0.;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* High precision square a * a (Shewchuk's version of Dekker-Veltkamp)
|
||||
*/
|
||||
void square(double a, double *square, double *e)
|
||||
{
|
||||
*square = a * a;
|
||||
#ifdef WITH_ROBUST_OPS
|
||||
double a_hi, a_low;
|
||||
double e_1, e_3;
|
||||
veltkamp_split(a, &a_hi, &a_low);
|
||||
e_1 = *square - (a_hi * a_hi);
|
||||
e_3 = e_1 - (a_low * (a_hi + a_hi));
|
||||
*e = (a_low * a_low) - e_3;
|
||||
#else
|
||||
*e = 0.;
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
 * High precision division a / b = *div + *e (Shewchuk's version of Dekker-Veltkamp)
 *
 * NOTE(review): the robust path computes a * (1/b) with an exact product
 * error, so *e captures the rounding of the multiplication but NOT the
 * rounding of the reciprocal 1/b itself; (*div, *e) is therefore a close
 * approximation of the exact quotient rather than an exact representation.
 */
void twoDiv(double const a, double const b, double *div, double *e)
{
#ifdef WITH_ROBUST_OPS
    twoMult(a, 1 / b, div, e);
#else
    *div = a / b;
    *e = 0.;
#endif
}
|
||||
|
||||
/*
 * Implementing the methods of the incremental statistics
 */
// Merge of two statistics objects into a new one; delegates to operator+=
// on a copy so the merge logic lives in a single place.
IncrementalStatistics IncrementalStatistics::operator + (IncrementalStatistics const & rhs) const
{
    IncrementalStatistics sum = *this;
    sum += rhs;
    return sum;
}
|
||||
|
||||
// Removal of rhs's contribution from a copy of this object; delegates to
// operator-= so the correction logic lives in a single place.
IncrementalStatistics IncrementalStatistics::operator - (IncrementalStatistics const & rhs) const
{
    IncrementalStatistics minus = *this;
    minus -= rhs;
    return minus;
}
|
||||
|
||||
// Merges rhs into this object. Totals and populations add directly; the
// sum-of-squared-deviations term s requires a cross-population correction
// (parallel variance merge).
IncrementalStatistics &IncrementalStatistics::operator += (IncrementalStatistics const & rhs)
{
    // for the term corresponding to the variance, the terms add up but there is a correcting term to
    // be added up as well. The proof of this can be found in the updated version of the HLD
    double const total_r = rhs.total;
    uint64_t const rhs_population = rhs.population;
    double increment_s = 0.;
    bool execute = (rhs_population > 0) && (population > 0);

    // this branch is as easy as it can get because population is mostly > 0
    if (likely(execute)) {
        // correction: T_r^2/n_r + T^2/n - (T_r + T)^2 / (n_r + n)
        increment_s = ((total_r * total_r) / rhs_population) + ((total * total) / population) -
            (((total_r + total) * (total_r + total)) / (rhs_population + population));
    }

    total += rhs.total;
    population += rhs.population;

    s += (rhs.s + increment_s);

    // running extrema of the combined population
    max_value = rhs.max_value > max_value ? rhs.max_value : max_value;
    min_value = rhs.min_value < min_value ? rhs.min_value : min_value;
    return *this;
}
|
||||
|
||||
// Removes rhs's contribution from this object (inverse of operator+=).
// NOTE(review): min_value/max_value are NOT recomputed here — after removal
// they may be stale extrema that no longer occur in the remaining population.
IncrementalStatistics &IncrementalStatistics::operator -= (IncrementalStatistics const & rhs)
{
    uint64_t const current_population = population - rhs.population;
    double const current_total = total - rhs.total;
    double decrement_s = 0.;
    bool const execute = (population > 0) && (current_population > 0) && (rhs.population > 0);

    // this branch is as easy as it can get because population is mostly > 0
    if (likely(execute)) {
        // inverse of the parallel-merge correction used in operator+=
        decrement_s = ((current_total * current_total) / current_population) +
            ((rhs.total * rhs.total) / rhs.population) - ((total * total) / population);
    }

    total = current_total;
    population = current_population;

    s -= (rhs.s + decrement_s);
    // rounding can drive s slightly negative; clamp via absolute value
    s = std::abs(s);
    return *this;
}
|
||||
|
||||
// Smallest value observed so far (DBL_MAX after reset()).
double IncrementalStatistics::getMin() const
{
    return min_value;
}

// Overrides the running minimum.
void IncrementalStatistics::setMin(double min)
{
    min_value = min;
}

// Largest value observed so far.
double IncrementalStatistics::getMax() const
{
    return max_value;
}

// Overrides the running maximum.
void IncrementalStatistics::setMax(double max)
{
    max_value = max;
}

// Running sum of all observed values.
double IncrementalStatistics::getTotal() const
{
    return total;
}

// Overrides the running sum.
void IncrementalStatistics::setTotal(double t)
{
    total = t;
}

// Number of observations accumulated so far.
uint64_t IncrementalStatistics::getPopulation() const
{
    return population;
}

// Overrides the observation count.
void IncrementalStatistics::setPopulation(uint64_t pop)
{
    population = pop;
}
|
||||
|
||||
// Mean of the observed values; returns 0 for an empty population.
double IncrementalStatistics::getEmpiricalMean() const
{
    double mean = 0.;
    if (likely(population > 0))
        mean = total / population;
    return mean;
}

// Population (biased, divide-by-n) variance s/n; returns 0 for an empty population.
double IncrementalStatistics::getEmpiricalVariance() const
{
    double variance = 0.;
    if (likely(population > 0))
        variance = s / population;
    return variance;
}

// Standard deviation: square root of the population variance.
double IncrementalStatistics::getEmpiricalStdDev() const
{
    return std::sqrt(getEmpiricalVariance());
}
|
||||
|
||||
// Zeroes every member (byte-wise, via the secure memset) and re-arms the
// running minimum so the first observation always replaces it.
// NOTE(review): max_value is left at 0 after reset, so a population of only
// negative values would report max 0 — confirm callers seed it via setMax().
bool IncrementalStatistics::reset()
{
    memset_s(this, sizeof(IncrementalStatistics), 0, sizeof(IncrementalStatistics));
    min_value = DBL_MAX;
    return true;
}
|
|
@ -0,0 +1,31 @@
|
|||
# gd.cmake
# Build definition for the DB4AI gradient-descent executor object target.

# Source files compiled into the gd object library.
set(TGT_gd_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/gd.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/linregr.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/logregr.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/matrix.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/optimizer_gd.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/optimizer_ngd.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/predict.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/shuffle_cache.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/svm.cpp
)
# Include search paths (server headers plus bundled third-party libraries).
set(TGT_gd_INC
    ${PROJECT_OPENGS_DIR}/contrib/log_fdw
    ${PROJECT_TRUNK_DIR}/distribute/bin/gds
    ${PROJECT_SRC_DIR}/include/libcomm
    ${PROJECT_SRC_DIR}/include
    ${PROJECT_SRC_DIR}/lib/gstrace
    ${LZ4_INCLUDE_PATH}
    ${LIBCGROUP_INCLUDE_PATH}
    ${LIBORC_INCLUDE_PATH}
    ${EVENT_INCLUDE_PATH}
    ${PROTOBUF_INCLUDE_PATH}
    ${ZLIB_INCLUDE_PATH}
)
# Preprocessor definitions and compile/link flags inherited from the global build.
set(gd_DEF_OPTIONS ${MACRO_OPTIONS})
set(gd_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS} -std=c++14 -fPIE)
# These objects link into an executable, so build position-independent-executable (-fPIE) code, not -fPIC.
list(REMOVE_ITEM gd_COMPILE_OPTIONS -fPIC)
set(gd_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor_gd TGT_gd_SRC TGT_gd_INC "${gd_DEF_OPTIONS}" "${gd_COMPILE_OPTIONS}" "${gd_LINK_OPTIONS}")
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
#---------------------------------------------------------------------------------------
#
# Makefile for the gradient-descent (gd) executor subdirectory.
#
# IDENTIFICATION
#        src/gausskernel/dbmind/db4ai/executor/gd/Makefile
#
# ---------------------------------------------------------------------------------------

subdir = src/gausskernel/dbmind/db4ai/executor/gd
top_builddir = ../../../../../..
include $(top_builddir)/src/Makefile.global

# Pull in auto-generated dependency files, except when cleaning or when
# building under the hutaf_llt unit-test toolchain.
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif

# Object files produced from this directory's sources.
OBJS = gd.o matrix.o optimizer_gd.o optimizer_ngd.o shuffle_cache.o logregr.o svm.o linregr.o predict.o

include $(top_srcdir)/src/gausskernel/common.mk
|
|
@ -0,0 +1,357 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* gd.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/gd.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "postgres.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "nodes/makefuncs.h"
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
#include "nodes/primnodes.h"
|
||||
#include "utils/array.h"
|
||||
|
||||
/*
 * Returns the human-readable name of a gradient-descent optimizer,
 * raising an ERROR for values beyond the known range.
 */
const char *gd_get_optimizer_name(OptimizerML optimizer)
{
    static const char *optimizer_names[] = { "gd", "ngd" };

    if (optimizer <= OPTIMIZER_NGD)
        return optimizer_names[optimizer];

    ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
        errmsg("Invalid optimizer %d", optimizer)));
    return NULL; // unreachable: ereport(ERROR) does not return
}
|
||||
|
||||
/*
 * Returns the display name of a GradientDescentExpr field. Score fields
 * (those carrying the GD_EXPR_SCORE flag) are delegated to gd_get_metric_name().
 */
const char *gd_get_expr_name(GradientDescentExprField field)
{
    // order must match the GradientDescentExprField enum values
    // NOTE(review): "processed"/"discarded" here differ from the projection
    // table's "processed_tuples"/"discarded_tuples" — confirm intentional.
    const char* names[] = {
        "algorithm",
        "optimizer",
        "result_type",
        "num_iterations",
        "exec_time_msecs",
        "processed",
        "discarded",
        "weights",
        "categories",
    };

    // score metrics are encoded as GD_EXPR_SCORE | METRIC_*
    if ((field & GD_EXPR_SCORE) != 0)
        return gd_get_metric_name(field & ~GD_EXPR_SCORE);

    if (field > GD_EXPR_CATEGORIES)
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Invalid GD expression field %d", field)));

    return names[field];
}
|
||||
|
||||
/*
 * Converts a gd_float value into a Datum of the requested type Oid,
 * truncating/rounding per the target type's conversion rules.
 * Raises ERROR for unsupported target types.
 */
Datum gd_float_get_datum(Oid type, gd_float value)
{
    Datum datum = 0;
    switch (type) {
        case BOOLOID:
            // any non-zero float maps to true
            datum = BoolGetDatum(value != 0.0);
            break;
        case INT1OID:
            datum = Int8GetDatum(value);
            break;
        case INT2OID:
            datum = Int16GetDatum(value);
            break;
        case INT4OID:
            datum = Int32GetDatum(value);
            break;
        case INT8OID:
            datum = Int64GetDatum(value);
            break;
        case FLOAT4OID:
            datum = Float4GetDatum(value);
            break;
        case FLOAT8OID:
            datum = Float8GetDatum(value);
            break;
        case NUMERICOID:
            // go through float4 -> numeric conversion in the backend
            datum = DirectFunctionCall1(float4_numeric, Float4GetDatum(value));
            break;
        default:
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("Oid type %d not yet supported", type)));
            break;
    }
    return datum;
}
|
||||
|
||||
/*
 * Converts a Datum of the given type Oid into a gd_float
 * (inverse of gd_float_get_datum). Raises ERROR for unsupported types.
 */
gd_float gd_datum_get_float(Oid type, Datum datum)
{
    gd_float value = 0;
    switch (type) {
        case BOOLOID:
            value = DatumGetBool(datum) ? 1.0 : 0.0;
            break;
        case INT1OID:
            value = DatumGetInt8(datum);
            break;
        case INT2OID:
            value = DatumGetInt16(datum);
            break;
        case INT4OID:
            value = DatumGetInt32(datum);
            break;
        case INT8OID:
            // may lose precision for values beyond the float mantissa range
            value = DatumGetInt64(datum);
            break;
        case FLOAT4OID:
            value = DatumGetFloat4(datum);
            break;
        case FLOAT8OID:
            value = DatumGetFloat8(datum);
            break;
        case NUMERICOID:
            // numeric -> float8 without overflow errors, then narrow
            value = DatumGetFloat8(DirectFunctionCall1(numeric_float8_no_overflow, datum));
            break;
        default:
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("Oid type %d not yet supported", type)));
            break;
    }
    return value;
}
|
||||
|
||||
char *gd_get_metric_name(int metric)
|
||||
{
|
||||
switch (metric) {
|
||||
case METRIC_ACCURACY:
|
||||
return "accuracy";
|
||||
case METRIC_F1:
|
||||
return "f1";
|
||||
case METRIC_PRECISION:
|
||||
return "precision";
|
||||
case METRIC_RECALL:
|
||||
return "recall";
|
||||
case METRIC_LOSS:
|
||||
return "loss";
|
||||
case METRIC_MSE:
|
||||
return "mse";
|
||||
default:
|
||||
Assert(false);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// algorithm descriptors, each defined in its own translation unit
extern GradientDescentAlgorithm gd_logistic_regression;
extern GradientDescentAlgorithm gd_svm_classification;
extern GradientDescentAlgorithm gd_linear_regression;

/*
 * Resolves an AlgorithmML id to its GradientDescentAlgorithm descriptor.
 * Raises ERROR for unsupported algorithms.
 */
GradientDescentAlgorithm *gd_get_algorithm(AlgorithmML algorithm)
{
    GradientDescentAlgorithm *gd_algorithm = nullptr;
    switch (algorithm) {
        case LOGISTIC_REGRESSION:
            gd_algorithm = &gd_logistic_regression;
            break;
        case SVM_CLASSIFICATION:
            gd_algorithm = &gd_svm_classification;
            break;
        case LINEAR_REGRESSION:
            gd_algorithm = &gd_linear_regression;
            break;
        default:
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Invalid algorithm %d", algorithm)));
            break;
    }
    return gd_algorithm;
}
|
||||
|
||||
// ////////////////////////////////////////////////////////////////////////////
// expressions for projections

// Fixed output fields of a GD plan: one row per projected column, carrying
// the expression field id, its SQL type, and its column name.
// NOTE(review): column names here ("processed_tuples"/"discarded_tuples")
// differ from the names array in gd_get_expr_name() — confirm intentional.
static struct {
    GradientDescentExprField field; // which GD expression this column carries
    Oid fieldtype;                  // SQL type Oid of the projected column
    char *name;                     // column name in the result set
} GradientDescentExpr_fields[] = {
    { GD_EXPR_ALGORITHM, INT4OID, "algorithm"},
    { GD_EXPR_OPTIMIZER, INT4OID, "optimizer"},
    { GD_EXPR_RESULT_TYPE, OIDOID, "result_type"},
    { GD_EXPR_NUM_ITERATIONS, INT4OID, "num_iterations"},
    { GD_EXPR_EXEC_TIME_MSECS, FLOAT4OID, "exec_time_msecs"},
    { GD_EXPR_PROCESSED_TUPLES, INT4OID, "processed_tuples"},
    { GD_EXPR_DISCARDED_TUPLES, INT4OID, "discarded_tuples"},
    { GD_EXPR_WEIGHTS, FLOAT4ARRAYOID, "weights"},
};
|
||||
|
||||
// Allocates a GradientDescentExpr plan node for one output field of the GD
// operator, recording the field id and its SQL result type.
static GradientDescentExpr *makeGradientDescentExpr(GradientDescentExprField field, Oid fieldtype)
{
    GradientDescentExpr *xpr = makeNode(GradientDescentExpr);
    xpr->field = field;
    xpr->fieldtype = fieldtype;
    return xpr;
}
|
||||
|
||||
/*
 * Appends one TargetEntry per GD output column to 'list': first the fixed
 * fields, then one column per metric supported by the chosen algorithm, and
 * finally (for binary dependent variables) the category mapping.
 * 'field' is the starting target resno and is incremented per entry.
 */
List *makeGradientDescentExpr(AlgorithmML algorithm, List *list, int field)
{
    Expr *expr;
    // fixed fields shared by all GD algorithms
    for (size_t i = 0; i < sizeof(GradientDescentExpr_fields) / sizeof(GradientDescentExpr_fields[0]); i++) {
        expr = (Expr *)makeGradientDescentExpr(GradientDescentExpr_fields[i].field,
            GradientDescentExpr_fields[i].fieldtype);
        list = lappend(list, makeTargetEntry(expr, field++, GradientDescentExpr_fields[i].name, false));
    }

    // add metrics
    GradientDescentAlgorithm *palgo = gd_get_algorithm(algorithm);
    int metrics = palgo->metrics;
    int metric = 1;
    // walk the metrics bitmask, one output column per set bit
    while (metrics != 0) {
        if (metrics & metric) {
            expr = (Expr *)makeGradientDescentExpr(makeGradientDescentExprFieldScore(metric), FLOAT4OID);
            list = lappend(list, makeTargetEntry(expr, field++, gd_get_metric_name(metric), false));
            metrics &= ~metric;
        }
        metric <<= 1;
    }

    // binary value mappings
    if (dep_var_is_binary(palgo)) {
        expr = (Expr *)makeGradientDescentExpr(GD_EXPR_CATEGORIES, TEXTARRAYOID);
        list = lappend(list, makeTargetEntry(expr, field++, "categories", false));
    }

    return list;
}
|
||||
|
||||
/*
 * Evaluates a score (metric) expression against the trained GD state.
 * The metric id is the expression field with the GD_EXPR_SCORE flag stripped.
 * Sets *isNull when the requested metric is undefined for the current scores.
 */
Datum ExecGDExprScore(GradientDescentExprState *mlstate, bool *isNull)
{
    Datum dt = 0;
    bool hasp, hasr;
    double precision, recall;
    GradientDescentState *gd_state = (GradientDescentState *)mlstate->ps;

    switch (mlstate->xpr->field & ~GD_EXPR_SCORE) {
        case METRIC_LOSS:
            dt = Float4GetDatum(gd_state->loss);
            break;
        case METRIC_ACCURACY:
            dt = Float4GetDatum(get_accuracy(&gd_state->scores));
            break;
        case METRIC_F1: // 2 * (precision * recall) / (precision + recall)
            precision = get_precision(&gd_state->scores, &hasp);
            recall = get_recall(&gd_state->scores, &hasr);
            // when either component is positive the denominator is non-zero
            if ((hasp && precision > 0) || (hasr && recall > 0)) {
                dt = Float4GetDatum(2.0 * precision * recall / (precision + recall));
            } else
                *isNull = true;
            break;
        case METRIC_PRECISION:
            precision = get_precision(&gd_state->scores, &hasp);
            if (hasp) {
                dt = Float4GetDatum(precision);
            } else
                *isNull = true;
            break;
        case METRIC_RECALL:
            recall = get_recall(&gd_state->scores, &hasr);
            if (hasr) {
                dt = Float4GetDatum(recall);
            } else
                *isNull = true;
            break;
        case METRIC_MSE:
            dt = Float4GetDatum(gd_state->scores.mse);
            break;
        default:
            // unknown metric: surface as SQL NULL rather than erroring
            *isNull = true;
            break;
    }

    return dt;
}
|
||||
|
||||
/*
 * Evaluates a non-metric GD expression (training metadata: algorithm id,
 * iteration counts, timing, weights array, category mapping) against the
 * finished GD state. Sets *isNull for unknown fields.
 */
Datum ExecNonGDExprScore(GradientDescentExprState *mlstate, ExprContext *econtext, bool *isNull)
{
    Datum dt = 0;
    Oid typoutput;
    GradientDescentState *gd_state = (GradientDescentState *)mlstate->ps;
    GradientDescent *gd_node = (GradientDescent *)gd_state->ss.ps.plan;
    ArrayBuildState *astate = NULL;

    switch (mlstate->xpr->field) {
        case GD_EXPR_ALGORITHM:
            dt = Int32GetDatum(gd_node->algorithm);
            break;
        case GD_EXPR_OPTIMIZER:
            dt = Int32GetDatum(gd_node->optimizer);
            break;
        case GD_EXPR_RESULT_TYPE:
            // SQL type of the dependent-variable column
            typoutput = get_atttypid(gd_state, gd_node->targetcol);
            dt = UInt32GetDatum(typoutput);
            break;
        case GD_EXPR_NUM_ITERATIONS:
            dt = Int32GetDatum(gd_state->n_iterations);
            break;
        case GD_EXPR_EXEC_TIME_MSECS:
            // stored in microseconds; reported in milliseconds
            dt = Float4GetDatum(gd_state->usecs / 1000.0);
            break;
        case GD_EXPR_PROCESSED_TUPLES:
            dt = Int32GetDatum(gd_state->processed);
            break;
        case GD_EXPR_DISCARDED_TUPLES:
            dt = Int32GetDatum(gd_state->discarded);
            break;
        case GD_EXPR_WEIGHTS: {
            // pack the trained weight vector into a float4[] result
            gd_float *pw = gd_state->weights.data;
            for (int i = 0; i < gd_state->weights.rows; i++)
                astate = accumArrayResult(astate, Float4GetDatum(*pw++), false, FLOAT4OID, CurrentMemoryContext);
            dt = makeArrayResult(astate, econtext->ecxt_per_query_memory);
        } break;
        case GD_EXPR_CATEGORIES:
            // array of the original dependent-variable values, in class order
            typoutput = get_atttypid(gd_state, gd_node->targetcol);
            for (int i = 0; i < gd_state->num_classes; i++)
                astate =
                    accumArrayResult(astate, gd_state->binary_classes[i], false, typoutput, CurrentMemoryContext);
            dt = makeArrayResult(astate, econtext->ecxt_per_query_memory);
            break;
        default:
            *isNull = true;
            break;
    }

    return dt;
}
|
||||
|
||||
/*
 * Expression-evaluation entry point for GradientDescentExpr nodes.
 * Dispatches score (metric) fields to ExecGDExprScore() and everything
 * else to ExecNonGDExprScore(); always produces a single result.
 */
Datum ExecEvalGradientDescent(GradientDescentExprState *mlstate, ExprContext *econtext, bool *isNull,
    ExprDoneCond *isDone)
{
    if (isDone != NULL)
        *isDone = ExprSingleResult;

    *isNull = false;

    // score expressions carry the GD_EXPR_SCORE flag; everything else is plain metadata
    const bool is_score = (mlstate->xpr->field & GD_EXPR_SCORE) != 0;
    return is_score ? ExecGDExprScore(mlstate, isNull) : ExecNonGDExprScore(mlstate, econtext, isNull);
}
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* linreg.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/linreg.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
// Computes the linear-regression gradient for one batch:
// gradients = x^T * (x * w - y). The result is accumulated into 'gradients'
// by matrix_mult_vector; temporaries are released before returning.
static void linear_reg_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    Matrix *weights, Matrix *gradients)
{
    Assert(features->rows > 0);

    // xT * (x * w - y)
    Matrix loss;
    matrix_init(&loss, features->rows);
    matrix_mult_vector(features, weights, &loss);
    matrix_subtract(&loss, dep_var);

    Matrix x_t;
    matrix_init_transpose(&x_t, features);
    matrix_mult_vector(&x_t, &loss, gradients);

    matrix_release(&x_t);
    matrix_release(&loss);
}
|
||||
|
||||
// Evaluates the linear-regression loss for one batch and accumulates the
// MSE score. Returns the batch loss using the conventional 1/2 factor:
// loss = sum((x * w - y)^2) / (2 * rows).
static double linear_reg_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    const Matrix *weights, Scores *scores)
{
    Assert(features->rows > 0);

    // loss = sum((x * w - y)^2) / 2m
    Matrix errors;
    matrix_init(&errors, features->rows);
    matrix_mult_vector(features, weights, &errors);
    matrix_subtract(&errors, dep_var);
    matrix_square(&errors);
    gd_float tuple_loss = matrix_get_sum(&errors) / features->rows;
    // mse accumulates the plain mean square, before the 1/2 loss factor
    scores->mse += tuple_loss;
    tuple_loss /= 2;

    matrix_release(&errors);

    return tuple_loss;
}
|
||||
|
||||
// Linear-regression prediction for a single feature vector: p = x * w.
static gd_float linear_reg_predict(const Matrix *features, const Matrix *weights)
{
    // p = x * w
    return matrix_dot(features, weights);
}
|
||||
|
||||
// Descriptor for linear regression: continuous dependent variable; the only
// reported metric is MSE (identical to the loss function); no binary class
// mapping values are needed.
GradientDescentAlgorithm gd_linear_regression = {
    "linear-regression",
    GD_DEPENDENT_VAR_CONTINUOUS,
    METRIC_MSE, // same as loss function
    0.0,
    0.0, // not necessary
    linear_reg_gradients,
    linear_reg_test,
    linear_reg_predict,
};
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* logreg.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/logreg.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
// Computes the logistic-regression gradient for one batch:
// gradients = x^T * (sigmoid(x * w) - y). Temporaries are released before
// returning.
static void logreg_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    Matrix *weights, Matrix *gradients)
{
    Assert(features->rows > 0);

    // xT * ((1.0 / (1.0 + exp(-x*w))) - y)
    Matrix sigma;
    matrix_init(&sigma, features->rows);
    matrix_mult_vector(features, weights, &sigma);
    matrix_sigmoid(&sigma);
    matrix_subtract(&sigma, dep_var);

    Matrix x_t;
    matrix_init_transpose(&x_t, features);
    matrix_mult_vector(&x_t, &sigma, gradients);

    matrix_release(&x_t);
    matrix_release(&sigma);
}
|
||||
|
||||
// Evaluates logistic regression on one batch: updates the relevance counters
// (for accuracy/precision/recall) using a 0.5 decision threshold and returns
// the batch's average cross-entropy loss.
static double logreg_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    const Matrix *weights, Scores *scores)
{
    Assert(features->rows > 0);

    Matrix predictions;
    Matrix cost1;
    Matrix cost2;
    Matrix classification;

    // p = 1 / (1 + exp(-x*w))  (sigmoid)
    matrix_init(&predictions, features->rows);
    matrix_mult_vector(features, weights, &predictions);
    matrix_sigmoid(&predictions);

    // compute relevance
    matrix_init_clone(&classification, &predictions);
    matrix_binary(&classification, 0.5, 0.0, 1.0);
    matrix_relevance(&classification, dep_var, scores, 1.0);
    matrix_release(&classification);

    // compute loss using cross-entropy
    // cost1 = y * log(p)
    matrix_init_clone(&cost1, &predictions);
    matrix_log(&cost1);
    matrix_mult_entrywise(&cost1, dep_var);

    // cost2 = (1-y) * log(1-p)
    matrix_complement(&predictions);
    matrix_log(&predictions);
    matrix_init_clone(&cost2, dep_var);
    matrix_complement(&cost2);
    matrix_mult_entrywise(&cost2, &predictions);

    // cost = sum(-cost1 - cost2) / N
    matrix_negate(&cost1);
    matrix_subtract(&cost1, &cost2);
    gd_float tuple_loss = matrix_get_sum(&cost1) / features->rows;
    matrix_release(&cost2);
    matrix_release(&cost1);

    matrix_release(&predictions);
    return tuple_loss;
}
|
||||
|
||||
// Logistic-regression prediction for a single feature vector: applies the
// sigmoid to x * w and thresholds at 0.5 to produce the class (0.0 or 1.0).
static gd_float logreg_predict(const Matrix *features, const Matrix *weights)
{
    // p = 1 / (1 + exp(-x*w))  (sigmoid)
    gd_float r = matrix_dot(features, weights);
    r = 1.0 / (1.0 + exp(-r));
    return r < 0.5 ? 0.0 : 1.0;
}
|
||||
|
||||
// Descriptor for logistic regression: binary dependent variable, full
// classification metric set plus the cross-entropy loss.
// The 0.0 / 1.0 entries are the two class values the dependent variable is
// mapped to — presumably min/max targets; confirm against the
// GradientDescentAlgorithm declaration.
GradientDescentAlgorithm gd_logistic_regression = {
    "logistic-regression",
    GD_DEPENDENT_VAR_BINARY,
    METRIC_ACCURACY | METRIC_F1 | METRIC_PRECISION | METRIC_RECALL | METRIC_LOSS,
    0.0,
    1.0,
    logreg_gradients,
    logreg_test,
    logreg_predict,
};
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* matrix.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/matrix.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/matrix.h"
|
||||
|
||||
// Maximum rows printed before the middle of a tall matrix is elided
// (when 'full' is false).
#define MATRIX_LIMITED_OUTPUT 30

/*
 * Prints a matrix into 'buf' in a bracketed, numpy-like layout:
 * vectors as "[a, b, ...]", matrices as "[[..],\n[..]]". When 'full' is
 * false, matrices taller than MATRIX_LIMITED_OUTPUT rows are elided in the
 * middle with "...". Transposed views are not supported.
 */
void matrix_print(const Matrix *matrix, StringInfo buf, bool full)
{
    Assert(matrix != nullptr);
    Assert(!matrix->transposed);
    const gd_float *pf = matrix->data;
    appendStringInfoChar(buf, '[');
    for (int r = 0; r < matrix->rows; r++) {
        // skip the middle rows of a tall matrix unless a full dump was asked
        if (!full && matrix->rows > MATRIX_LIMITED_OUTPUT && r > (MATRIX_LIMITED_OUTPUT / 2) &&
            r < matrix->rows - (MATRIX_LIMITED_OUTPUT / 2)) {
            if (matrix->columns > 1)
                appendStringInfoString(buf, ",\n...");
            else
                appendStringInfoString(buf, ", ...");

            // jump straight to the tail rows
            r = matrix->rows - MATRIX_LIMITED_OUTPUT / 2;
            continue;
        }

        if (matrix->columns > 1) {
            if (r > 0)
                appendStringInfoString(buf, ",\n");

            appendStringInfoChar(buf, '[');
        } else {
            if (r > 0)
                appendStringInfoString(buf, ", ");
        }
        for (int c = 0; c < matrix->columns; c++) {
            if (c > 0)
                appendStringInfoString(buf, ", ");

            // full double precision
            appendStringInfo(buf, "%.16g", *pf++);
        }
        if (matrix->columns > 1)
            appendStringInfoChar(buf, ']');
    }
    appendStringInfoChar(buf, ']');
}
|
||||
|
||||
// Logs a matrix at the given log level, prefixed with 'msg'; the matrix is
// rendered in abbreviated form (middle rows elided for tall matrices).
void elog_matrix(int elevel, const char *msg, const Matrix *matrix)
{
    StringInfoData buf;
    initStringInfo(&buf);
    matrix_print(matrix, &buf, false);
    ereport(elevel, (errmodule(MOD_DB4AI), errmsg("%s = %s", msg, buf.data)));
    pfree(buf.data);
}
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* optimizer_gd.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/optimizer_gd.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
// ////////////////////////////////////////////////////////////////////////
|
||||
// gd: minibatch basic optimizer
|
||||
|
||||
// State of the plain minibatch gradient-descent optimizer.
typedef struct OptimizerMinibatch {
    OptimizerGD opt;                        // base struct with the callback table; must be first (casts rely on it)
    const GradientDescentState *gd_state;   // owning executor state (borrowed, not owned here)
    double learning_rate;                   // current step size; multiplied by decay at each end_iteration
} OptimizerMinibatch;
|
||||
|
||||
// End-of-iteration hook: apply the configured learning-rate decay,
// so that the effective rate follows decay^iterations.
static void opt_gd_end_iteration(OptimizerGD *optimizer)
{
    OptimizerMinibatch *self = (OptimizerMinibatch *)optimizer;
    double decay_factor = gd_get_node(self->gd_state)->decay;
    self->learning_rate *= decay_factor;
}
|
||||
|
||||
// Performs one minibatch step: computes the gradients of the loss for this
// batch with the algorithm-specific callback and applies them to the weights
// using the current learning rate averaged over the batch size.
static void opt_gd_update_batch(OptimizerGD *optimizer, const Matrix *features, const Matrix *dep_var)
{
    OptimizerMinibatch *opt = (OptimizerMinibatch *)optimizer;

    // clear gradients of the batch
    matrix_zeroes(&optimizer->gradients);

    // update gradients (accumulated by the algorithm callback, e.g. SVM / logistic)
    opt->gd_state->algorithm->gradients_callback(gd_get_node(opt->gd_state), features, dep_var, &optimizer->weights,
        &optimizer->gradients);

    elog_matrix(DEBUG1, "optimizer gd: gradients", &optimizer->gradients);

    // add gradients to the model: weight -= alpha * gradients * scale / N
    matrix_mult_scalar(&optimizer->gradients, opt->learning_rate / features->rows);
    matrix_subtract(&optimizer->weights, &optimizer->gradients);

    elog_matrix(DEBUG1, "optimizer gd: weights", &optimizer->weights);
}
|
||||
|
||||
// Releases the minibatch optimizer; it owns no resources beyond its own struct.
static void opt_gd_release(OptimizerGD *optimizer)
{
    pfree(optimizer);
}
|
||||
|
||||
// Creates a plain minibatch gradient-descent optimizer bound to the
// given executor state. The returned pointer aliases the embedded base
// struct, whose callbacks dispatch back into this file.
OptimizerGD *gd_init_optimizer_gd(const GradientDescentState *gd_state)
{
    OptimizerMinibatch *optimizer = (OptimizerMinibatch *)palloc0(sizeof(OptimizerMinibatch));

    // algorithm state
    optimizer->gd_state = gd_state;
    optimizer->learning_rate = gd_get_node(gd_state)->learning_rate;

    // callback table (no per-iteration setup is required)
    optimizer->opt.start_iteration = nullptr;
    optimizer->opt.end_iteration = opt_gd_end_iteration;
    optimizer->opt.update_batch = opt_gd_update_batch;
    optimizer->opt.release = opt_gd_release;

    return &optimizer->opt;
}
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* optimizer_ngd.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/optimizer_ngd.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
// ////////////////////////////////////////////////////////////////////////
|
||||
// ngd: normalized gradient descent optimizer
|
||||
//
|
||||
// An adaptation of NG algorithm from:
|
||||
// Ross, Stéphane, Paul Mineiro, and John Langford.
|
||||
// "Normalized online learning." arXiv preprint arXiv:1305.6646 (2013).
|
||||
|
||||
|
||||
// State of the normalized gradient-descent (NGD) optimizer.
typedef struct OptimizerNormalize {
    OptimizerGD opt;                        // base struct with the callback table; must be first (casts rely on it)
    const GradientDescentState *gd_state;   // owning executor state (borrowed, not owned here)
    double learning_rate;                   // current step size; multiplied by decay at each end_iteration
    bool learn; // only first iteration: per-feature scales are learned once
    double scale_rate;                      // accumulated qx / max(qx) over all seen feature values
    Matrix scale_gradients;                 // per-feature running maximum of squared feature values
} OptimizerNormalize;
|
||||
|
||||
// End-of-iteration hook for NGD: apply learning-rate decay (rate follows
// decay^iterations) and freeze scale learning after the first pass.
static void opt_ngd_end_iteration(OptimizerGD *optimizer)
{
    OptimizerNormalize *self = (OptimizerNormalize *)optimizer;

    double decay_factor = gd_get_node(self->gd_state)->decay;
    self->learning_rate *= decay_factor;

    // normalization scales are learned only during the first iteration
    self->learn = false;
}
|
||||
|
||||
// Performs one normalized-gradient-descent batch step. During the first
// iteration it also learns per-feature scale factors from the squared
// feature values (adaptation of Ross, Mineiro & Langford, arXiv:1305.6646);
// gradients are rescaled by those factors before being applied.
static void opt_ngd_update_batch(OptimizerGD *optimizer, const Matrix *features, const Matrix *dep_var)
{
    OptimizerNormalize *opt = (OptimizerNormalize *)optimizer;

    if (opt->learn) {
        // first iteration only: track the per-feature maximum of x^2 and
        // shrink the matching weight whenever a larger value appears
        Assert(features->columns == opt->scale_gradients.rows);

        gd_float *pf = features->data;
        for (int r = 0; r < features->rows; r++) {
            gd_float *pw = optimizer->weights.data;
            gd_float *ps = opt->scale_gradients.data;
            for (int c = 0; c < features->columns; c++) {
                gd_float qx = *pf++;
                qx *= qx; // squared feature value
                if (qx > *ps) {
                    // update weights and scaling of gradients
                    *pw *= *ps / qx;
                    *ps = qx;
                }
                if (*ps > 0) {
                    // update scale rate
                    opt->scale_rate += qx / *ps;
                }
                ps++;
                pw++;
            }
        }
    }

    // clear gradients of the batch, then recompute them for this batch
    matrix_zeroes(&optimizer->gradients);
    opt->gd_state->algorithm->gradients_callback(gd_get_node(opt->gd_state), features, dep_var, &optimizer->weights,
        &optimizer->gradients);

    elog_matrix(DEBUG1, "optimizer ngd: gradients", &optimizer->gradients);

    // normalize gradients: divide each component by its learned scale and
    // by the accumulated scale rate; features never seen keep scale 0
    gd_float *pg = optimizer->gradients.data;
    gd_float *ps = opt->scale_gradients.data;
    for (int r = 0; r < opt->scale_gradients.rows; r++) {
        gd_float s = 0.0;
        if (*ps > 0)
            s = (1.0 / opt->scale_rate) / *ps;

        *pg *= s;

        ps++;
        pg++;
    }

    // add gradients to the model: weight -= alpha * scale_rate * gradients * scale_gradients
    // do not divide by the number of rows like in a simple minibatch
    matrix_mult_scalar(&optimizer->gradients, opt->learning_rate);
    matrix_subtract(&optimizer->weights, &optimizer->gradients);

    elog_matrix(DEBUG1, "optimizer ngd: weights", &optimizer->weights);
}
|
||||
|
||||
// Releases the NGD optimizer, including the per-feature scale vector
// allocated by gd_init_optimizer_ngd(). Without the matrix_release the
// matrix data buffer was leaked (only the struct itself was freed), in
// contrast with cache_release() which releases matrices before pfree.
static void opt_ngd_release(OptimizerGD *optimizer)
{
    OptimizerNormalize *opt = (OptimizerNormalize *)optimizer;
    matrix_release(&opt->scale_gradients);
    pfree(optimizer);
}
|
||||
|
||||
// Creates a normalized-gradient-descent optimizer bound to the given
// executor state, with per-feature scale learning enabled for the first
// iteration. The returned pointer aliases the embedded base struct.
OptimizerGD *gd_init_optimizer_ngd(const GradientDescentState *gd_state)
{
    OptimizerNormalize *optimizer = (OptimizerNormalize *)palloc0(sizeof(OptimizerNormalize));

    // algorithm state
    optimizer->gd_state = gd_state;
    optimizer->learning_rate = gd_get_node(gd_state)->learning_rate;
    optimizer->learn = true;
    optimizer->scale_rate = 0.0;
    matrix_init(&optimizer->scale_gradients, gd_state->n_features);

    // callback table (no per-iteration setup is required)
    optimizer->opt.start_iteration = nullptr;
    optimizer->opt.end_iteration = opt_ngd_end_iteration;
    optimizer->opt.update_batch = opt_ngd_update_batch;
    optimizer->opt.release = opt_ngd_release;

    return &optimizer->opt;
}
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* predict.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/predict.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "db4ai/predict_by.h"
|
||||
|
||||
// State needed to run predictions with a trained gradient-descent model.
typedef struct GradientDescentPredictor {
    GradientDescentAlgorithm *algorithm;    // algorithm callbacks (predict_callback, ...)
    int ncategories;                        // number of categories; > 0 only for classification
    Oid return_type;                        // output type for regression results
    Matrix weights;                         // model coefficients, including the trailing bias term
    Matrix features;                        // scratch input vector reused for every prediction
    Datum *categories;                      // category values, or nullptr for regression
} GradientDescentPredictor;
|
||||
|
||||
// Deserializes a trained gradient-descent model into a predictor:
// unpacks the float4 weights array and, for classification models,
// the category values from the model warehouse representation.
ModelPredictor gd_predict_prepare(const Model *model)
{
    GradientDescentPredictor *gdp = (GradientDescentPredictor *)palloc0(sizeof(GradientDescentPredictor));
    ModelGradientDescent *gdp_model = (ModelGradientDescent *)model;
    gdp->algorithm = gd_get_algorithm(gdp_model->model.algorithm);
    gdp->ncategories = gdp_model->ncategories;
    gdp->return_type = gdp_model->model.return_type;

    // weights are stored as a one-dimensional float4 array
    ArrayType *arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(gdp_model->weights));
    Assert(arr->elemtype == FLOAT4OID);

    int coefficients = ARR_DIMS(arr)[0];
    matrix_init(&gdp->weights, coefficients);
    matrix_init(&gdp->features, coefficients); // scratch buffer, same length as weights

    Datum dt;
    bool isnull;
    gd_float *pf = gdp->weights.data;
    ArrayIterator it = array_create_iterator(arr, 0);
    while (array_iterate(it, &dt, &isnull)) {
        Assert(!isnull);
        *pf++ = DatumGetFloat4(dt);
    }
    array_free_iterator(it);
    // free the detoasted copy only if detoasting actually made one
    if (arr != (ArrayType *)DatumGetPointer(gdp_model->weights))
        pfree(arr);

    if (gdp->ncategories > 0) {
        // classification: load the category values, kept as raw Datums
        arr = (ArrayType *)pg_detoast_datum((struct varlena *)DatumGetPointer(gdp_model->categories));
        gdp->categories = (Datum *)palloc(ARR_DIMS(arr)[0] * sizeof(Datum));

        int cat = 0;
        it = array_create_iterator(arr, 0);
        while (array_iterate(it, &dt, &isnull)) {
            Assert(!isnull);
            gdp->categories[cat++] = dt;
        }
        array_free_iterator(it);
        if (arr != (ArrayType *)DatumGetPointer(gdp_model->categories))
            pfree(arr);
    } else
        gdp->categories = nullptr;

    // NOTE(review): cast is (ModelPredictor *) although the function returns
    // ModelPredictor — verify against the ModelPredictor typedef
    return (ModelPredictor *)gdp;
}
|
||||
|
||||
// Predicts one tuple with a prepared GD model: packs the input columns plus
// the bias term into the scratch feature vector, runs the algorithm's
// predict callback and maps the raw result to an output Datum.
Datum gd_predict(ModelPredictor pred, Datum *values, bool *isnull, Oid *types, int ncolumns)
{
    // extract coefficients from model
    GradientDescentPredictor *gdp = (GradientDescentPredictor *)pred;

    // extract the features (the last weight is the bias, hence rows - 1)
    if (ncolumns != (int)gdp->weights.rows - 1)
        elog(ERROR, "Invalid number of features for prediction, provided %d, expected %d", ncolumns,
            gdp->weights.rows - 1);

    gd_float *w = gdp->features.data;
    for (int i = 0; i < ncolumns; i++) {
        double value;
        if (isnull[i])
            value = 0.0; // default value for feature, it is not the target for sure
        else
            value = gd_datum_get_float(types[i], values[i]);

        *w++ = value;
    }
    *w = 1.0; // bias

    Datum result = 0;
    gd_float r = gdp->algorithm->predict_callback(&gdp->features, &gdp->weights);
    if (dep_var_is_binary(gdp->algorithm)) {
        // classification: map the raw {min_class, max_class} output back to
        // the original category values
        if (gdp->ncategories == 0) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INTERNAL_ERROR),
                errmsg("For classification algorithms: %s, the number of categories should not be 0.",
                    gdp->algorithm->name)));
        }
        result = gdp->categories[0];
        if (r != gdp->algorithm->min_class) {
            if (gdp->ncategories == 2 && r == gdp->algorithm->max_class)
                result = gdp->categories[1];
        }
    } else {
        result = gd_float_get_datum(gdp->return_type, r);
    }

    return result;
}
|
|
@ -0,0 +1,280 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* shuffle_cache.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/shuffle_cache.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
/*
|
||||
* Shuffle using a limited cache is performed by caching and training batches
|
||||
* in random order. For each new batch there are two random options:
|
||||
* - append (there are free slots into the cache)
|
||||
* - or train (and test) an existing batch and replace it by the new one
|
||||
* At the end of the iteration, the cache is emptied randomly. At each step,
|
||||
* one remaining batch is selected, trained and removed.
|
||||
* At the end of the two phases, all batches have been visited only once in a
|
||||
* random sequence, but the ability to shuffle batches depends on the cache size
|
||||
* (the available working memory) and the batch size (a matrix of N rows
|
||||
* by M features). The probability distribution is not uniform and initial
|
||||
* batches have a higher probability than the last batches, but the shuffling
|
||||
* is good enough and has a very small impact into the performance of GD.
|
||||
*/
|
||||
|
||||
// Shuffle state: a bounded cache of batches trained in random order
// (see the file-head comment for the two-phase shuffling scheme).
typedef struct ShuffleCache {
    ShuffleGD shf;              // base struct with the callback table; must be first (casts rely on it)
    int cache_size;             // number of batch slots that fit into working memory
    int *cache_batch;           // per-slot batch sequence number (used for tracing)
    Matrix **cache_features;    // per-slot feature matrices
    Matrix **cache_dep_var;     // per-slot dependent-variable vectors
    int cache_allocated;        // slots physically allocated so far (lazy allocation)
    int max_cache_usage;        // usable slots = min(cache_size, num_batches), fixed after iteration 0
    int num_batches;            // batches seen during the current iteration
    int batch_size;             // rows per full batch
    int n_features;             // columns per feature matrix
    int iteration;              // current iteration, starting at 0
    int cached;                 // slots currently occupied
    int next;                   // slot handed out by cache_get, or -1 when none is pending
    struct drand48_data rnd;    // private reentrant PRNG state
} ShuffleCache;
|
||||
|
||||
// Draws the next non-negative pseudo-random value from the cache's private
// reentrant PRNG stream (seeded once in gd_init_shuffle_cache).
inline int32_t rnd_next(ShuffleCache *shf) {
    long int value;
    lrand48_r(&shf->rnd, &value);
    return (int32_t)value;
}
|
||||
|
||||
// Trains the batch currently held in slot cache->next through the optimizer,
// then restores the slot's matrices to full batch size in case they had been
// shrunk by cache_unget() for a partial last batch.
static void update_batch(ShuffleCache *cache)
{
    Assert(cache->next < cache->cached);

    ereport(DEBUG1, (errmodule(MOD_DB4AI), errmsg("GD shuffle cache iteration %d train batch %d of %d",
        cache->iteration, cache->cache_batch[cache->next] + 1, cache->num_batches)));

    Matrix *features = cache->cache_features[cache->next];
    Matrix *dep_var = cache->cache_dep_var[cache->next];

    cache->shf.optimizer->update_batch(cache->shf.optimizer, features, dep_var);

    // undo a temporary shrink of a partial batch so the slot can be reused
    if (features->rows < cache->batch_size) {
        matrix_resize(features, cache->batch_size, cache->n_features);
        matrix_resize(dep_var, cache->batch_size, 1);
    }
}
|
||||
|
||||
// Moves the last occupied slot into the hole at cache->next by swapping the
// matrix pointers and copying the batch number (the caller has already
// decremented cache->cached, so it indexes the last valid slot).
static void swap_last(ShuffleCache *cache)
{
    int hole = cache->next;
    int last = cache->cached;

    Matrix *tmp_features = cache->cache_features[hole];
    cache->cache_features[hole] = cache->cache_features[last];
    cache->cache_features[last] = tmp_features;

    Matrix *tmp_dep_var = cache->cache_dep_var[hole];
    cache->cache_dep_var[hole] = cache->cache_dep_var[last];
    cache->cache_dep_var[last] = tmp_dep_var;

    cache->cache_batch[hole] = cache->cache_batch[last];
}
|
||||
|
||||
// Resets the per-iteration bookkeeping: no slot handed out, cache empty,
// batch counter back to zero.
static void cache_start_iteration(ShuffleGD *shuffle)
{
    ShuffleCache *self = (ShuffleCache *)shuffle;
    self->num_batches = 0;
    self->cached = 0;
    self->next = -1;
}
|
||||
|
||||
// Ends an iteration. After iteration 0 the usable cache size is fixed;
// for later iterations the remaining cached batches are trained in random
// order until the cache is empty.
static void cache_end_iteration(ShuffleGD *shuffle)
{
    ShuffleCache *cache = (ShuffleCache *)shuffle;
    if (cache->iteration == 0) {
        // iteration 0 trains batches immediately (no shuffle), so the cache
        // is already empty; clamp the usable size to the observed batch count
        cache->max_cache_usage = cache->cache_size;
        if (cache->max_cache_usage > cache->num_batches)
            cache->max_cache_usage = cache->num_batches;
    } else {
        // empty the cache: pick a random remaining batch, train it,
        // and fill the hole with the last occupied slot
        while (cache->cached > 0) {
            cache->next = rnd_next(cache) % cache->cached;
            update_batch(cache);

            cache->cached--;
            if (cache->next != cache->cached)
                swap_last(cache);
        }
    }
    cache->iteration++;
}
|
||||
|
||||
// Frees every batch ever allocated (matrix data, then the matrix structs),
// the slot arrays, and finally the cache itself.
static void cache_release(ShuffleGD *shuffle)
{
    ShuffleCache *cache = (ShuffleCache *)shuffle;

    for (int slot = 0; slot < cache->cache_allocated; slot++) {
        Matrix *features = cache->cache_features[slot];
        matrix_release(features);
        pfree(features);

        Matrix *dep_var = cache->cache_dep_var[slot];
        matrix_release(dep_var);
        pfree(dep_var);
    }

    pfree(cache->cache_features);
    pfree(cache->cache_dep_var);
    pfree(cache->cache_batch);
    pfree(cache);
}
|
||||
|
||||
// Returns the matrices of the slot where the next batch must be loaded.
// Iteration 0 always reuses slot 0 and trains immediately (no shuffle);
// later iterations either reuse a random occupied slot (training its old
// content first) or append to a free slot, allocating it lazily.
static Matrix *cache_get(ShuffleGD *shuffle, Matrix **pdep_var)
{
    ShuffleCache *cache = (ShuffleCache *)shuffle;
    Assert(cache->next == -1);

    Matrix *features;
    Matrix *dep_var;
    if (cache->iteration == 0) {
        // special case, do not shuffle
        Assert(cache->cached == 0);
        cache->next = 0;
        cache->cached++;
        if (cache->cache_allocated == 0) {
            features = (Matrix *)palloc0(sizeof(Matrix));
            matrix_init(features, cache->batch_size, cache->n_features);

            dep_var = (Matrix *)palloc0(sizeof(Matrix));
            matrix_init(dep_var, cache->batch_size);

            cache->cache_features[0] = features;
            cache->cache_dep_var[0] = dep_var;
            cache->cache_allocated++;
        } else {
            // reuse the batch, it has been already allocated
            features = cache->cache_features[0];
            dep_var = cache->cache_dep_var[0];
        }
    } else {
        // look for an empty slot, otherwise reuse one
        cache->next = rnd_next(cache) % cache->max_cache_usage;
        if (cache->next < cache->cached) {
            // reuse slot: train the batch it holds before overwriting it
            update_batch(cache);
            features = cache->cache_features[cache->next];
            dep_var = cache->cache_dep_var[cache->next];
        } else {
            // append
            cache->next = cache->cached++;
            if (cache->next == cache->cache_allocated) {
                // lazy allocation of a new slot
                features = (Matrix *)palloc0(sizeof(Matrix));
                matrix_init(features, cache->batch_size, cache->n_features);

                dep_var = (Matrix *)palloc0(sizeof(Matrix));
                matrix_init(dep_var, cache->batch_size);

                cache->cache_features[cache->next] = features;
                cache->cache_dep_var[cache->next] = dep_var;
                cache->cache_allocated++;
            } else {
                features = cache->cache_features[cache->next];
                dep_var = cache->cache_dep_var[cache->next];
            }
        }
    }
    cache->cache_batch[cache->next] = cache->num_batches;
    cache->num_batches++;

    *pdep_var = dep_var;
    return features;
}
|
||||
|
||||
// Informs the cache how many tuples were actually read into the slot obtained
// from cache_get(). A count of zero discards the slot; a partial batch shrinks
// the slot's matrices temporarily (update_batch restores the full size).
static void cache_unget(ShuffleGD *shuffle, int tuples)
{
    ShuffleCache *cache = (ShuffleCache *)shuffle;
    Assert(cache->next != -1);
    Assert(cache->cached > 0);

    if (tuples == 0) {
        // ignore batch
        if (cache->iteration == 0) {
            cache->cached--;
        } else {
            // special case when the last batch is empty
            // there are two potential cases
            cache->cached--;
            if (cache->next < cache->cached) {
                // it is in the middle, swap it with the last
                swap_last(cache);
            }
        }
        cache->num_batches--;
    } else {
        if (tuples < cache->batch_size) {
            // resize batch temporarily
            Matrix *features = cache->cache_features[cache->next];
            Matrix *dep_var = cache->cache_dep_var[cache->next];
            matrix_resize(features, tuples, cache->n_features);
            matrix_resize(dep_var, tuples, 1);
        }

        if (cache->iteration == 0) {
            // iteration 0 trains immediately and releases the slot
            Assert(cache->next == 0);
            update_batch(cache);
            cache->cached--;
        }
    }

    cache->next = -1;
}
|
||||
|
||||
// Creates the shuffle cache. The number of cached batches is derived from
// the available working memory (work_mem minus the weights and gradients
// vectors); with a single slot no real shuffling is possible.
ShuffleGD *gd_init_shuffle_cache(const GradientDescentState *gd_state)
{
    int batch_size = gd_get_node(gd_state)->batch_size;

    // check if a batch fits into memory
    int64_t avail_mem = u_sess->attr.attr_memory.work_mem * 1024LL -
        matrix_expected_size(gd_state->n_features) * 2; // weights & gradients

    // NOTE(review): the dep-var estimate below uses n_features, but the
    // dep-var vector has batch_size rows — confirm whether this should be
    // matrix_expected_size(batch_size)
    int batch_mem = matrix_expected_size(batch_size, gd_state->n_features) // features
        + matrix_expected_size(gd_state->n_features); // dep var
    if (batch_mem > avail_mem)
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("batch size is too large for the available working memory")));

    // initialize
    ShuffleCache *cache = (ShuffleCache *)palloc0(sizeof(ShuffleCache));
    srand48_r(gd_get_node(gd_state)->seed, &cache->rnd);
    cache->shf.start_iteration = cache_start_iteration;
    cache->shf.end_iteration = cache_end_iteration;
    cache->shf.release = cache_release;
    cache->shf.get = cache_get;
    cache->shf.unget = cache_unget;

    // cache for shuffle; also accounts for the per-slot bookkeeping arrays
    cache->cache_size = avail_mem / (batch_mem + 2 * sizeof(Matrix *) + sizeof(int));
    if (cache->cache_size == 0)
        cache->cache_size = 1; // shuffle is not possible

    ereport(NOTICE, (errmodule(MOD_DB4AI), errmsg("GD shuffle cache size %d", cache->cache_size)));

    cache->batch_size = batch_size;
    cache->n_features = gd_state->n_features;
    cache->cache_batch = (int *)palloc(cache->cache_size * sizeof(int));
    cache->cache_features = (Matrix **)palloc(cache->cache_size * sizeof(Matrix *));
    cache->cache_dep_var = (Matrix **)palloc(cache->cache_size * sizeof(Matrix *));
    cache->cache_allocated = 0;
    cache->max_cache_usage = 0;
    cache->iteration = 0;

    return &cache->shf;
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* svm.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/gd/svm.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
// Accumulates SVM hinge-loss gradients for a batch: for every row that
// violates the margin (1 - y * (x . w) > 0) the row's features scaled by -y
// are added into the gradients vector.
static void svmc_gradients(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    Matrix *weights, Matrix *gradients)
{
    Assert(features->rows > 0);

    // distances = 1 - y * (x * w)
    Matrix distance;
    matrix_init(&distance, features->rows);
    matrix_mult_vector(features, weights, &distance);
    matrix_mult_entrywise(&distance, dep_var);
    matrix_complement(&distance);

    Assert(distance.rows == dep_var->rows);
    Assert(distance.columns == 1);

    const gd_float *pf = features->data;
    const gd_float *py = dep_var->data;
    const gd_float *pd = distance.data;
    gd_float *pg = gradients->data;
    for (int r = 0; r < distance.rows; r++) {
        gd_float y = *py++;
        gd_float d = *pd++;
        if (d > 0) {
            // margin violated: gradient -= y * x for this row
            for (int f = 0; f < features->columns; f++)
                pg[f] -= y * pf[f];
        }
        pf += features->columns;
    }
    matrix_release(&distance);

    // NOTE(review): this scales the entire hinge gradient by 2*lambda and no
    // separate 2*lambda*w regularization term is added — confirm this matches
    // the intended SVM objective
    matrix_mult_scalar(gradients, gd_node->lambda * 2.0);
}
|
||||
|
||||
// Evaluates the SVM model on a batch: fills the relevance scores (accuracy,
// precision, recall, ...) from sign predictions and returns the average
// hinge loss per tuple.
static double svmc_test(const GradientDescent *gd_node, const Matrix *features, const Matrix *dep_var,
    const Matrix *weights, Scores *scores)
{
    Assert(features->rows > 0);

    Matrix distances;
    Matrix predictions;
    matrix_init(&distances, features->rows);
    matrix_mult_vector(features, weights, &distances);

    // predictions = sign(x . w), mapped onto the class labels {-1, 1}
    matrix_init_clone(&predictions, &distances);
    matrix_positive(&predictions);
    matrix_binary(&predictions, FLT_MIN, -1.0, 1.0);
    matrix_relevance(&predictions, dep_var, scores, 1.0);
    matrix_release(&predictions);

    // cost = (1 - y * x * w), clamped at zero (hinge loss)
    matrix_mult_entrywise(&distances, dep_var);
    matrix_complement(&distances);
    matrix_positive(&distances);

    gd_float tuple_loss = matrix_get_sum(&distances) / features->rows;

    matrix_release(&distances);
    return tuple_loss;
}
|
||||
|
||||
// Binary SVM decision: the sign of the dot product <features, weights>,
// mapped onto the two class labels {-1, +1}.
static gd_float svmc_predict(const Matrix *features, const Matrix *weights)
{
    double score = matrix_dot(features, weights);
    if (score < 0)
        return -1.0;
    return 1.0;
}
|
||||
|
||||
// Algorithm descriptor for binary SVM classification, registered with the
// gradient-descent executor.
GradientDescentAlgorithm gd_svm_classification = {
    "svm-classification",
    GD_DEPENDENT_VAR_BINARY,
    METRIC_ACCURACY | METRIC_F1 | METRIC_PRECISION | METRIC_RECALL | METRIC_LOSS,
    -1.0, // min_class label
    1.0,  // max_class label
    svmc_gradients,
    svmc_test,
    svmc_predict,
};
|
|
@ -0,0 +1,401 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*
|
||||
*
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/executor/hyperparameter_validation.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "db4ai/hyperparameter_validation.h"
|
||||
|
||||
#include "nodes/plannodes.h"
|
||||
|
||||
|
||||
// Number of elements in a statically-sized array. The whole expansion is
// parenthesized so the macro composes safely inside larger expressions
// (e.g. `n / ARRAY_LENGTH(a)` previously misparsed as `n / sizeof(a) / sizeof(a[0])`).
#define ARRAY_LENGTH(x) (sizeof(x) / sizeof((x)[0]))
|
||||
|
||||
// Fills a HyperparameterValidation struct from the given range limits
// (with per-bound inclusiveness) and/or list of valid enum values.
static void init_hyperparameter_validation(HyperparameterValidation *v, void *min_value, bool min_inclusive,
    void *max_value, bool max_inclusive, const char **valid_values, int32_t valid_values_size)
{
    v->min_value = min_value;
    v->min_inclusive = min_inclusive;
    v->max_value = max_value;
    v->max_inclusive = max_inclusive;
    v->valid_values = valid_values;
    v->valid_values_size = valid_values_size;
}
|
||||
|
||||
// Definitions of hyperparameters: each macro expands to an aggregate
// initializer for one HyperparameterDefinition entry
// (name, default, min, max, enum values, enum setter, type oid,
//  enum size, struct offset, min_inclusive, max_inclusive).

// boolean hyperparameter (no range, no enum values)
#define HYPERPARAMETER_BOOL(name, default_value, struct_name, attribute) \
    { \
        name, BoolGetDatum(default_value), PointerGetDatum(NULL), PointerGetDatum(NULL), NULL, NULL, BOOLOID, 0, \
        offsetof(struct_name, attribute), false, false \
    }

// enum hyperparameter taking one of enum_values, converted by enum_setter
#define HYPERPARAMETER_ENUM(name, default_value, enum_values, enum_values_size, enum_setter, struct_name, attribute) \
    { \
        name, CStringGetDatum(default_value), PointerGetDatum(NULL), PointerGetDatum(NULL), enum_values, enum_setter, \
        ANYENUMOID, enum_values_size, offsetof(struct_name, attribute), false, false \
    }

// int32 hyperparameter constrained to [min, max] (inclusivity per flags)
#define HYPERPARAMETER_INT4(name, default_value, min, min_inclusive, max, max_inclusive, struct_name, attribute) \
    { \
        name, Int32GetDatum(default_value), Int32GetDatum(min), Int32GetDatum(max), NULL, NULL, INT4OID, 0, \
        offsetof(struct_name, attribute), min_inclusive, max_inclusive \
    }

// float8 hyperparameter constrained to [min, max] (inclusivity per flags)
#define HYPERPARAMETER_FLOAT8(name, default_value, min, min_inclusive, max, max_inclusive, struct_name, attribute) \
    { \
        name, Float8GetDatum(default_value), Float8GetDatum(min), Float8GetDatum(max), NULL, NULL, FLOAT8OID, 0, \
        offsetof(struct_name, attribute), min_inclusive, max_inclusive \
    }
|
||||
|
||||
// optimizer names accepted by the "optimizer" hyperparameter
const char* gd_optimizer_ml[] = {"gd", "ngd"};
// Used by linear regression and logistic regression
HyperparameterDefinition logistic_regression_hyperparameter_definitions[] = {
    HYPERPARAMETER_INT4("batch_size", 1000, 1, true, INT32_MAX, true,
        GradientDescent, batch_size),
    HYPERPARAMETER_FLOAT8("decay", 0.95, 0.0, false, DBL_MAX, true,
        GradientDescent, decay),
    HYPERPARAMETER_FLOAT8("learning_rate", 0.8, 0.0, false, DBL_MAX, true,
        GradientDescent, learning_rate),
    HYPERPARAMETER_INT4("max_iterations", 100, 1, true, INT32_MAX, true,
        GradientDescent, max_iterations),
    HYPERPARAMETER_INT4("max_seconds", 0, 0, true, INT32_MAX, true,
        GradientDescent, max_seconds),
    HYPERPARAMETER_ENUM("optimizer", "gd", gd_optimizer_ml, ARRAY_LENGTH(gd_optimizer_ml), optimizer_ml_setter,
        GradientDescent, optimizer),
    HYPERPARAMETER_FLOAT8("tolerance", 0.0005, 0.0, false, DBL_MAX, true,
        GradientDescent, tolerance),
    HYPERPARAMETER_INT4("seed", 0, 0, true, INT32_MAX, true,
        GradientDescent, seed),
    HYPERPARAMETER_BOOL("verbose", false,
        GradientDescent, verbose),
};

// SVM shares the GD hyperparameters and adds the regularization "lambda"
HyperparameterDefinition svm_hyperparameter_definitions[] = {
    HYPERPARAMETER_INT4("batch_size", 1000, 1, true, INT32_MAX, true,
        GradientDescent, batch_size),
    HYPERPARAMETER_FLOAT8("decay", 0.95, 0.0, false, DBL_MAX, true,
        GradientDescent, decay),
    HYPERPARAMETER_FLOAT8("lambda", 0.01, 0.0, false, DBL_MAX, true,
        GradientDescent, lambda),
    HYPERPARAMETER_FLOAT8("learning_rate", 0.8, 0.0, false, DBL_MAX, true,
        GradientDescent, learning_rate),
    HYPERPARAMETER_INT4("max_iterations", 100, 1, true, INT32_MAX, true,
        GradientDescent, max_iterations),
    HYPERPARAMETER_INT4("max_seconds", 0, 0, true, INT32_MAX, true,
        GradientDescent, max_seconds),
    HYPERPARAMETER_ENUM("optimizer", "gd", gd_optimizer_ml, ARRAY_LENGTH(gd_optimizer_ml), optimizer_ml_setter,
        GradientDescent, optimizer),
    HYPERPARAMETER_FLOAT8("tolerance", 0.0005, 0.0, false, DBL_MAX, true,
        GradientDescent, tolerance),
    HYPERPARAMETER_INT4("seed", 0, 0, true, INT32_MAX, true,
        GradientDescent, seed),
    HYPERPARAMETER_BOOL("verbose", false,
        GradientDescent, verbose),
};

HyperparameterDefinition kmeans_hyperparameter_definitions[] = {
    /* nothing to do now, will do when needing */
};
|
||||
|
||||
|
||||
// Returns the hyperparameter definition table and its length for the given
// ML algorithm; errors out for unknown algorithms.
void get_hyperparameter_definitions(AlgorithmML algorithm, HyperparameterDefinition **result, int32_t *result_size)
{
    switch (algorithm) {
        case LOGISTIC_REGRESSION:
        case LINEAR_REGRESSION:
            // both regressions share the same definition table
            *result = logistic_regression_hyperparameter_definitions;
            *result_size = ARRAY_LENGTH(logistic_regression_hyperparameter_definitions);
            break;

        case SVM_CLASSIFICATION:
            *result = svm_hyperparameter_definitions;
            *result_size = ARRAY_LENGTH(svm_hyperparameter_definitions);
            break;

        case KMEANS:
            *result = kmeans_hyperparameter_definitions;
            *result_size = ARRAY_LENGTH(kmeans_hyperparameter_definitions);
            break;

        case INVALID_ALGORITHM_ML:
        default: {
            // const: C++11 forbids binding a string literal to char*;
            // braces scope the declaration to this case label
            const char *s = "logistic_regression, svm_classification, linear_regression, kmeans";
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Architecture is not supported. Supported architectures: %s", s)));
        }
    }
}
|
||||
|
||||
|
||||
/*
 * Append a (name, type, value) hyperparameter record to the model's list.
 * The name string is duplicated into the current memory context so the
 * caller's buffer need not outlive the model.
 */
static void add_model_hyperparameter(Model *model, const char *name, Oid type, Datum value)
{
    Hyperparameter *entry = (Hyperparameter *)palloc0(sizeof(Hyperparameter));

    entry->name = pstrdup(name); /* private copy of the caller's string */
    entry->type = type;
    entry->value = value;

    model->hyperparameters = lappend(model->hyperparameters, entry);
}
|
||||
|
||||
/*
 * Overwrite the type and value of an already-registered model hyperparameter,
 * matched by name. The hyperparameter must exist (asserted in debug builds).
 */
void update_model_hyperparameter(Model *model, const char *name, Oid type, Datum value) {
    ListCell *cell = NULL;

    foreach (cell, model->hyperparameters) {
        Hyperparameter *entry = lfirst_node(Hyperparameter, cell);
        if (strcmp(entry->name, name) != 0)
            continue;

        entry->type = type;
        entry->value = value;
        return;
    }

    /* callers may only update hyperparameters that were previously added */
    Assert(false);
}
|
||||
|
||||
// Set int hyperparameter.
//
// Resolves an integer hyperparameter from a SET-style clause, records it on
// the model, then range-checks it against the validation bounds.
//
//   name           - hyperparameter name (used in messages only)
//   hyperparameter - out: receives the resolved value
//   value          - parsed value node; must be T_Integer when kind == VAR_SET_VALUE
//   kind           - VAR_SET_DEFAULT or VAR_SET_VALUE; anything else errors out
//   default_value  - value applied when kind == VAR_SET_DEFAULT
//   model          - model that records the (name, value) pair
//   validation     - bounds spec; the range check runs only when both min_value
//                    and max_value are non-NULL. NOTE(review): validation itself
//                    is dereferenced unconditionally — callers must not pass NULL.
void set_hyperparameter_value(const char *name, int *hyperparameter, Value *value, VariableSetKind kind,
    int default_value, Model *model, HyperparameterValidation *validation)
{
    if (kind == VAR_SET_DEFAULT) {
        *hyperparameter = default_value;
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%d)", name, *hyperparameter)));
    } else if (kind == VAR_SET_VALUE) {
        if (value == NULL) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s cannot take NULL value", name)));
        } else if (value->type != T_Integer) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be an integer", name)));
        }
        *hyperparameter = intVal(value);
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value %d", name, *hyperparameter)));
    } else {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Invalid hyperparameter value for %s ", name)));
    }

    // The value is recorded before the range check; an out-of-range value then
    // aborts via ereport(ERROR), discarding the partial model with the transaction.
    add_model_hyperparameter(model, name, INT4OID, Int32GetDatum(*hyperparameter));
    if (validation->min_value != NULL && validation->max_value != NULL) {
        // inclusive-bounds check first, then tighten each end for exclusive bounds
        bool out_of_range =
            (*hyperparameter < *(int *)validation->min_value || *hyperparameter > *(int *)validation->max_value);
        if (!validation->min_inclusive && *hyperparameter <= *(int *)validation->min_value)
            out_of_range = true;
        if (!validation->max_inclusive && *hyperparameter >= *(int *)validation->max_value)
            out_of_range = true;
        if (out_of_range) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be in the range %c%d,%d%c", name, validation->min_inclusive ? '[' : '(',
                *(int *)validation->min_value, *(int *)validation->max_value, validation->max_inclusive ? ']' : ')')));
        }
    }
}
|
||||
|
||||
|
||||
// Set double hyperparameter.
//
// Resolves a floating-point hyperparameter from a SET-style clause, records it
// on the model, then range-checks it against the validation bounds. Integer
// literals are accepted and widened to double.
//
//   name           - hyperparameter name (used in messages only)
//   hyperparameter - out: receives the resolved value
//   value          - parsed value node; T_Float or T_Integer when kind == VAR_SET_VALUE
//   kind           - VAR_SET_DEFAULT or VAR_SET_VALUE; anything else errors out
//   default_value  - value applied when kind == VAR_SET_DEFAULT
//   model          - model that records the (name, value) pair
//   validation     - bounds spec; range check runs only when both bounds are set.
//                    NOTE(review): validation is dereferenced unconditionally —
//                    callers must not pass NULL.
void set_hyperparameter_value(const char *name, double *hyperparameter, Value *value, VariableSetKind kind,
    double default_value, Model *model, HyperparameterValidation *validation)
{
    if (kind == VAR_SET_DEFAULT) {
        *hyperparameter = default_value;
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%f)", name, *hyperparameter)));
    } else if (kind == VAR_SET_VALUE) {
        if (value == NULL) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s cannot take NULL value", name)));
        } else if (value->type == T_Float) {
            *hyperparameter = floatVal(value);
        } else if (value->type == T_Integer) {
            // accept integer literals (e.g. "SET x = 1") as doubles
            *hyperparameter = (double)intVal(value);
        } else {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be a floating point number", name)));
        }
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value %f", name, *hyperparameter)));
    } else {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Invalid hyperparameter value for %s ", name)));
    }

    // Recorded before validation; an out-of-range value aborts the transaction.
    add_model_hyperparameter(model, name, FLOAT8OID, Float8GetDatum(*hyperparameter));
    if (validation->min_value != NULL && validation->max_value != NULL) {
        // inclusive-bounds check first, then tighten each end for exclusive bounds
        bool out_of_range =
            (*hyperparameter < *(double *)validation->min_value || *hyperparameter > *(double *)validation->max_value);
        if (!validation->min_inclusive && *hyperparameter <= *(double *)validation->min_value)
            out_of_range = true;
        if (!validation->max_inclusive && *hyperparameter >= *(double *)validation->max_value)
            out_of_range = true;
        if (out_of_range) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be in the range %c%.8g,%.8g%c", name,
                validation->min_inclusive ? '[' : '(', *(double *)validation->min_value,
                *(double *)validation->max_value, validation->max_inclusive ? ']' : ')')));
        }
    }
}
|
||||
|
||||
|
||||
// Set string hyperparameter (no const).
//
// Resolves a string hyperparameter from a SET-style clause, checks it against
// the list of valid values (if any), and records it on the model as text.
//
//   name           - hyperparameter name (used in messages only)
//   hyperparameter - out: receives the resolved string. NOTE(review): ownership
//                    is inconsistent — the DEFAULT path pallocs a fresh copy,
//                    while the VALUE path aliases the parse node's string
//                    (strVal). Callers should not pfree it; confirm lifetime.
//   value          - parsed value node; must be T_String when kind == VAR_SET_VALUE
//   kind           - VAR_SET_DEFAULT or VAR_SET_VALUE; anything else errors out
//   default_value  - value applied when kind == VAR_SET_DEFAULT
//   model          - model that records the (name, value) pair
//   validation     - optional whitelist (valid_values/valid_values_size);
//                    dereferenced unconditionally — must not be NULL.
void set_hyperparameter_value(const char *name, char **hyperparameter, Value *value, VariableSetKind kind,
    char *default_value, Model *model, HyperparameterValidation *validation)
{
    if (kind == VAR_SET_DEFAULT) {
        // copy the default so *hyperparameter is always writable storage here
        *hyperparameter = (char*)palloc((strlen(default_value) + 1) * sizeof(char));
        errno_t err = strcpy_s(*hyperparameter, strlen(default_value) + 1, default_value);
        securec_check(err, "\0", "\0");
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%s)", name, *hyperparameter)));
    } else if (kind == VAR_SET_VALUE) {
        if (value == NULL) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s cannot take NULL value", name)));
        } else if (value->type != T_String) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be a string", name)));
        }
        *hyperparameter = strVal(value);
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value %s", name, *hyperparameter)));
    } else {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Invalid hyperparameter value for %s ", name)));
    }

    // Enumerated-value check: reject anything outside the whitelist, listing
    // the accepted values in the error message.
    if (validation->valid_values != NULL) {
        bool found = false;
        for (int i = 0; i < validation->valid_values_size; i++) {
            if (0 == strcmp(validation->valid_values[i], *hyperparameter)) {
                found = true;
                break;
            }
        }
        if (!found) {
            // build a comma-separated list of the accepted values
            StringInfo str = makeStringInfo();
            for (int i = 0; i < validation->valid_values_size; i++) {
                if (i != 0)
                    appendStringInfoString(str, ", ");
                appendStringInfoString(str, (validation->valid_values)[i]);
            }
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Invalid hyperparameter value for %s. Valid values are: %s. (default is %s)", name, str->data,
                default_value)));
        }
    }
    // stored as text (VARCHAROID) in the model's hyperparameter list
    add_model_hyperparameter(model, name, VARCHAROID, PointerGetDatum(cstring_to_text(*hyperparameter)));
}
|
||||
|
||||
|
||||
// Render a boolean as an uppercase literal for user-facing messages.
inline const char *bool_to_str(bool value)
{
    if (value)
        return "TRUE";
    return "FALSE";
}
|
||||
|
||||
// Set boolean hyperparameter.
//
// Resolves a boolean hyperparameter from a SET-style clause and records it on
// the model. Accepts the strings 'true'/'false' (lowercase only) or any
// integer (non-zero => true).
//
//   name           - hyperparameter name (used in messages only)
//   hyperparameter - out: receives the resolved value
//   value          - parsed value node; T_String or T_Integer when kind == VAR_SET_VALUE
//   kind           - VAR_SET_DEFAULT or VAR_SET_VALUE; anything else errors out
//   default_value  - value applied when kind == VAR_SET_DEFAULT
//   model          - model that records the (name, value) pair
//   validation     - unused for booleans (callers pass NULL)
void set_hyperparameter_value(const char *name, bool *hyperparameter, Value *value, VariableSetKind kind,
    bool default_value, Model *model, HyperparameterValidation *validation)
{
    if (kind == VAR_SET_DEFAULT) {
        *hyperparameter = default_value;
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value DEFAULT (%s)", name, bool_to_str(*hyperparameter))));
    } else if (kind == VAR_SET_VALUE) {
        if (value == NULL) {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s cannot take NULL value", name)));
        } else if (value->type == T_String) {
            // only the exact lowercase spellings 'true' and 'false' are accepted
            char *str = strVal(value);
            if (strcmp(str, "true") == 0) {
                *hyperparameter = true;
            } else if (strcmp(str, "false") == 0) {
                *hyperparameter = false;
            } else {
                ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                    errmsg("Hyperparameter %s is not a valid string for boolean (i.e. 'true' or 'false')", name)));
            }
        } else if (value->type == T_Integer) {
            // C-style truthiness: any non-zero integer means true
            *hyperparameter = (intVal(value) != 0);
        } else {
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("Hyperparameter %s must be a boolean or integer", name)));
        }
        ereport(NOTICE, (errmsg("Hyperparameter %s takes value %s", name, bool_to_str(*hyperparameter))));
    } else {
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("Invalid hyperparameter value for %s ", name)));
    }

    add_model_hyperparameter(model, name, BOOLOID, BoolGetDatum(*hyperparameter));
}
|
||||
|
||||
|
||||
// Set the hyperparameters according to the definitions. In the process, the
// range and values of each parameter are validated.
//
//   algorithm             - selects which definition table to apply
//   hyperparameters       - user-supplied SET clauses (List of settings)
//   model                 - model object that records each resolved value
//   hyperparameter_struct - algorithm-specific struct; each definition's
//                           offset locates the field to fill in
//
// Dispatches on the declared OID of each definition: INT4/FLOAT8 get
// min/max validation, BOOL gets none, ANYENUM is resolved as a string
// against the whitelist and then converted via the definition's enum_setter.
void configure_hyperparameters(AlgorithmML algorithm, List *hyperparameters, Model *model, void *hyperparameter_struct)
{
    HyperparameterDefinition *definitions;
    int32_t definitions_size;
    get_hyperparameter_definitions(algorithm, &definitions, &definitions_size);

    HyperparameterValidation validation;
    for (int32_t i = 0; i < definitions_size; i++) {
        switch (definitions[i].type) {
            case INT4OID: {
                // target field inside the algorithm's hyperparameter struct
                int *value_addr = (int *)((char *)hyperparameter_struct + definitions[i].offset);
                int value_min = DatumGetInt32(definitions[i].min_value);
                int value_max = DatumGetInt32(definitions[i].max_value);
                init_hyperparameter_validation(&validation, &value_min, definitions[i].min_inclusive, &value_max,
                    definitions[i].max_inclusive, NULL, 0);
                set_hyperparameter<int>(definitions[i].name, value_addr, hyperparameters,
                    DatumGetInt32(definitions[i].default_value), model, &validation);
                break;
            }

            case FLOAT8OID: {
                double *value_addr = (double *)((char *)hyperparameter_struct + definitions[i].offset);
                double value_min = DatumGetFloat8(definitions[i].min_value);
                double value_max = DatumGetFloat8(definitions[i].max_value);
                init_hyperparameter_validation(&validation, &value_min, definitions[i].min_inclusive, &value_max,
                    definitions[i].max_inclusive, NULL, 0);
                set_hyperparameter<double>(definitions[i].name, value_addr, hyperparameters,
                    DatumGetFloat8(definitions[i].default_value), model, &validation);
                break;
            }

            case BOOLOID: {
                // booleans carry no range restrictions, so no validation is built
                bool *value_addr = (bool *)((char *)hyperparameter_struct + definitions[i].offset);
                set_hyperparameter<bool>(definitions[i].name, value_addr, hyperparameters,
                    DatumGetBool(definitions[i].default_value), model, NULL);
                break;
            }

            case ANYENUMOID: {
                // resolved first as a string against valid_values, then mapped
                // to the enum field through the definition's setter callback
                void *value_addr = (void *)((char *)hyperparameter_struct + definitions[i].offset);
                char *str = NULL;
                init_hyperparameter_validation(&validation, NULL, NULL, NULL, NULL, definitions[i].valid_values,
                    definitions[i].valid_values_size);
                set_hyperparameter<char *>(definitions[i].name, &str, hyperparameters,
                    DatumGetCString(definitions[i].default_value), model, &validation);
                definitions[i].enum_setter(str, value_addr);
                break;
            }

            default: {
                ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                    errmsg("Invalid hyperparameter OID %d for hyperparameter %s", definitions[i].type,
                    definitions[i].name)));
            }
        }
    }
}
|
|
@ -0,0 +1,22 @@
|
|||
# kmeans.cmake
#
# Builds the db4ai executor k-means sources into a static object target.
set(TGT_kmeans_SRC
    ${CMAKE_CURRENT_SOURCE_DIR}/kmeans.cpp
)
# Header search paths for the target.
set(TGT_kmeans_INC
    ${PROJECT_OPENGS_DIR}/contrib/log_fdw
    # NOTE(review): "kmeanss" (double 's') looks like a typo for "kmeans" -- confirm this path exists
    ${PROJECT_TRUNK_DIR}/distribute/bin/kmeanss
    ${PROJECT_SRC_DIR}/include/libcomm
    ${PROJECT_SRC_DIR}/include
    ${PROJECT_SRC_DIR}/lib/gstrace
    ${LZ4_INCLUDE_PATH}
    ${LIBCGROUP_INCLUDE_PATH}
    ${LIBORC_INCLUDE_PATH}
    ${EVENT_INCLUDE_PATH}
    ${PROTOBUF_INCLUDE_PATH}
    ${ZLIB_INCLUDE_PATH}
)
# Compile/link option sets; this target is built -fPIE, so -fPIC inherited
# from the shared option lists is removed below.
set(kmeans_DEF_OPTIONS ${MACRO_OPTIONS})
set(kmeans_COMPILE_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS} -std=c++14 -fPIE)
list(REMOVE_ITEM kmeans_COMPILE_OPTIONS -fPIC)
set(kmeans_LINK_OPTIONS ${OPTIMIZE_OPTIONS} ${OS_OPTIONS} ${PROTECT_OPTIONS} ${WARNING_OPTIONS} ${SECURE_OPTIONS} ${CHECK_OPTIONS})
add_static_objtarget(gausskernel_db4ai_executor_kmeans TGT_kmeans_SRC TGT_kmeans_INC "${kmeans_DEF_OPTIONS}" "${kmeans_COMPILE_OPTIONS}" "${kmeans_LINK_OPTIONS}")
|
|
@ -0,0 +1,23 @@
|
|||
#---------------------------------------------------------------------------------------
#
# Makefile for the db4ai k-means executor objects.
#
# IDENTIFICATION
#        src/gausskernel/dbmind/db4ai/executor/kmeans
#
# ---------------------------------------------------------------------------------------

subdir = src/gausskernel/dbmind/db4ai/executor/kmeans
top_builddir = ../../../../../..

include $(top_builddir)/src/Makefile.global

# Pull in auto-generated dependency files, except when cleaning.
# NOTE(review): the inner check skips the -include when g++ resolves through
# the hutaf_llt toolchain -- presumably the LLT build; confirm intent.
ifneq "$(MAKECMDGOALS)" "clean"
ifneq "$(MAKECMDGOALS)" "distclean"
ifneq "$(shell which g++ |grep hutaf_llt |wc -l)" "1"
-include $(DEPEND)
endif
endif
endif

OBJS = kmeans.o

include $(top_srcdir)/src/gausskernel/common.mk
|
|
@ -0,0 +1,699 @@
|
|||
/**
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
kmeans.cpp
|
||||
k-means operations
|
||||
|
||||
|
||||
IDENTIFICATION
|
||||
src/gausskernel/dbmind/db4ai/executor/kmeans/kmeans.cpp
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
* */
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "nodes/execnodes.h"
|
||||
#include "nodes/pg_list.h"
|
||||
#include "postgres_ext.h"
|
||||
|
||||
#include "db4ai/fp_ops.h"
|
||||
#include "db4ai/distance_functions.h"
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "db4ai/predict_by.h"
|
||||
#include "db4ai/db4ai_cpu.h"
|
||||
|
||||
/*
|
||||
* to verify that the pg array we get complies with what we expect
|
||||
* we discard entries that do not pass the test because we cannot start
|
||||
* much (consistent) with them anyway
|
||||
*/
|
||||
bool verify_pgarray(ArrayType const * pg_array, int32_t n)
|
||||
{
|
||||
/*
|
||||
* We expect the input to be an n-element array of doubles; verify that. We
|
||||
* don't need to use deconstruct_array() since the array data is just
|
||||
* going to look like a C array of n double values.
|
||||
*/
|
||||
if (unlikely((ARR_NDIM(pg_array) != 1) || (ARR_DIMS(pg_array)[0] != n) || ARR_HASNULL(pg_array) ||
|
||||
(ARR_ELEMTYPE(pg_array) != FLOAT8OID)))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
 * Copy the coordinates found inside a slot into a GSPoint we own.
 *
 * Detoasts the first column of the slot (allocating when the value is
 * toasted), validates it as a 1-D float8 array of the given dimension, and
 * stores the array plus an ownership flag into *coordinates.
 *
 * Returns true when the point is valid; false on null arguments or a
 * malformed array.
 *
 * NOTE(review): when the point is invalid AND was a fresh copy, the copy is
 * pfree'd but the (now dangling) pointer is still stored in
 * coordinates->pg_coordinates — callers presumably must not touch it when
 * this returns false; confirm.
 */
bool copy_slot_coordinates_to_array(GSPoint *coordinates, TupleTableSlot const * slot, uint32_t const dimension)
{
    if (unlikely((slot == nullptr) or (coordinates == nullptr)))
        return false;

    /*
     * we obtain the coordinates of the current point (the call incurs
     * detoasting, so memory may be allocated and must be freed once we no
     * longer need it)
     */
    ArrayType *current_point_pgarray = DatumGetArrayTypePCopy(slot->tts_values[0]);
    bool const valid_point = verify_pgarray(current_point_pgarray, dimension);
    /*
     * should_free tracks whether we received fresh storage.
     * NOTE(review): a "PCopy" detoast usually returns a copy unconditionally,
     * which would make this comparison always true — confirm against the
     * openGauss definition of DatumGetArrayTypePCopy.
     */
    bool release_point = PointerGetDatum(current_point_pgarray) != slot->tts_values[0];

    /*
     * if the point is not valid and it was originally toasted, release the copy
     */
    if (unlikely(!valid_point && release_point)) {
        pfree(current_point_pgarray);
        release_point = false;
    }

    coordinates->pg_coordinates = current_point_pgarray;
    coordinates->should_free = release_point;

    return valid_point;
}
|
||||
|
||||
/*
|
||||
* this sets the weights of a set of candidates to 1 (every point is the centroid of itself)
|
||||
*/
|
||||
void reset_weights(List const * centroids)
|
||||
{
|
||||
ListCell const * current_centroid_cell = centroids ? centroids->head : nullptr;
|
||||
GSPoint *centroid = nullptr;
|
||||
|
||||
for (; current_centroid_cell != nullptr; current_centroid_cell = lnext(current_centroid_cell)) {
|
||||
centroid = reinterpret_cast<GSPoint *>(lfirst(current_centroid_cell));
|
||||
centroid->weight = 1U;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* given a set of centroids (as a PG list) and a point, this function compute the distance to the closest
|
||||
* centroid
|
||||
*/
|
||||
bool closest_centroid(List const * centroids, GSPoint const * point, uint32_t const dimension, double *distance)
|
||||
{
|
||||
ListCell const * current_centroid_cell = centroids ? centroids->head : nullptr;
|
||||
GSPoint *centroid = nullptr;
|
||||
GSPoint *closest_centroid_ptr = nullptr;
|
||||
bool result = false;
|
||||
bool min_distance_changed = false;
|
||||
double local_distance = 0.;
|
||||
auto min_distance = DBL_MAX;
|
||||
auto const * point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(point->pg_coordinates));
|
||||
|
||||
for (; current_centroid_cell != nullptr; current_centroid_cell = lnext(current_centroid_cell)) {
|
||||
/*
|
||||
* low temporal locality for a prefetch for read
|
||||
*/
|
||||
prefetch(lnext(current_centroid_cell), 0, 1);
|
||||
centroid = reinterpret_cast<GSPoint *>(lfirst(current_centroid_cell));
|
||||
local_distance = l2_squared(point_coordinates,
|
||||
reinterpret_cast<double const *>(ARR_DATA_PTR(centroid->pg_coordinates)), dimension);
|
||||
min_distance_changed = local_distance < min_distance;
|
||||
min_distance = min_distance_changed ? local_distance : min_distance;
|
||||
closest_centroid_ptr = min_distance_changed ? centroid : closest_centroid_ptr;
|
||||
}
|
||||
|
||||
if (closest_centroid_ptr) {
|
||||
result = true;
|
||||
++closest_centroid_ptr->weight;
|
||||
}
|
||||
|
||||
if (likely(distance != nullptr))
|
||||
*distance = min_distance;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* given a set of centroids (as a PG list) and a set of points, this function computes
|
||||
* the cost of the set of centroids as well as their weights (number of points assigned
|
||||
* to each centroid)
|
||||
* we assumed that all weights have been reset already
|
||||
*/
|
||||
bool compute_cost_and_weights(List const * centroids, GSPoint const * points, uint32_t dimension,
|
||||
uint32_t const num_slots, double *cost)
|
||||
{
|
||||
GSPoint const * point = nullptr;
|
||||
double local_distance = 0.;
|
||||
double cost_of_batch = 0.;
|
||||
double local_op_error = 0.;
|
||||
double total_op_error = 0.;
|
||||
uint32_t current_slot = 0U;
|
||||
bool const result = current_slot < num_slots;
|
||||
|
||||
// for every point we compute the closest centroid and increase the weight of the corresponding centroid
|
||||
while (current_slot < num_slots) {
|
||||
point = points + current_slot;
|
||||
closest_centroid(centroids, point, dimension, &local_distance);
|
||||
twoSum(cost_of_batch, local_distance, &cost_of_batch, &local_op_error);
|
||||
total_op_error += local_op_error;
|
||||
++current_slot;
|
||||
}
|
||||
cost_of_batch += total_op_error;
|
||||
*cost = cost_of_batch;
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
 * Run k-means++ on a super-set of candidate centroids to obtain the k seed
 * centroids.
 *
 * description            - global k-means state (centroid arrays, counters)
 * centroids_candidates   - PG list of weighted GSPoint candidates; consumed
 * idx_current_centroids  - which centroid set in description to fill
 * size_centroid_bytes    - byte size of one centroid's coordinate buffer
 * prng                   - Mersenne-Twister generator used for sampling
 *
 * Candidates chosen as centroids are removed from the list; all remaining
 * candidates are freed before returning. Returns the (now empty) list.
 * Side effect: description->current_centroid is advanced to the number of
 * centroids produced.
 */
List *kmeanspp(KMeansStateDescription *description, List *centroids_candidates, uint32_t const idx_current_centroids,
    uint32_t const size_centroid_bytes, std::mt19937_64 *prng)
{
    Centroid *centroids = description->centroids[idx_current_centroids];
    uint32_t const num_centroids_needed = description->num_centroids;
    uint32_t num_candidates = centroids_candidates ? centroids_candidates->length : 0;
    uint32_t const dimension = description->dimension;
    ListCell *current_candidate_cell = nullptr;
    ListCell *prev_candidate_cell = nullptr;
    ListCell *tmp_cell = nullptr;
    std::uniform_real_distribution<double> unit_sampler(0., 1.);
    // if there are fewer candidates than centroids needed, all of them become centroids
    double sample_probability = num_candidates <= num_centroids_needed ? 1. : 1. / static_cast<double>(num_candidates);
    double sample_probability_correction = 0.;
    double candidate_probability = 0.;
    double distance = 0.;
    double sum_distances = 0.;
    double sum_distances_local_correction = 0.;
    double sum_distances_correction = 0.;
    ArrayType *current_candidate_pgarray = nullptr;
    ArrayType *current_centroid_pgarray = nullptr;
    GSPoint *current_candidate = nullptr;
    GSPoint *tmp_candidate = nullptr;
    uint32_t current_centroid_idx = description->current_centroid;
    uint32_t tries_until_next_centroid = 0;
    bool no_more_candidates = false;

    /*
     * we expect to produce all centroids in one go; seeding must start from an
     * empty centroid set and k must be positive
     */
    if ((current_centroid_idx > 0) || (num_centroids_needed == 0))
        ereport(ERROR,
            (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("k-means: not able to run k-means++")));
    /*
     * the very first centroid is sampled uniformly at random (not weighted);
     * subsequent centroids are sampled w.r.t. the (weighted) distance to their
     * closest centroid: points far from their centroid are more likely to be
     * picked, and heavier points also have better chances
     */
    while (likely((current_centroid_idx < num_centroids_needed) && !no_more_candidates)) {
        ++tries_until_next_centroid;
        /*
         * the probability of not choosing any centroid in a round is > 0, so we
         * loop until all k centroids are found; each round selects at most one
         */
        prev_candidate_cell = nullptr;
        no_more_candidates = true;
        current_candidate_cell = centroids_candidates ? centroids_candidates->head : nullptr;
        for (; current_candidate_cell != nullptr; current_candidate_cell = lnext(current_candidate_cell)) {
            current_candidate = reinterpret_cast<GSPoint *>(lfirst(current_candidate_cell));
            current_candidate_pgarray = current_candidate->pg_coordinates;

            no_more_candidates = false;
            candidate_probability = unit_sampler(*prng);

            /*
             * after the first centroid exists, each candidate's acceptance
             * probability is its (already weighted) distance to its closest
             * centroid divided by the sum of all such distances (compensated
             * division via twoDiv)
             */
            if (likely(current_centroid_idx > 0)) {
                // this distance is already weighted and thus requires no weighting again
                distance = current_candidate->distance_to_closest_centroid;
                twoDiv(distance, sum_distances, &sample_probability, &sample_probability_correction);
                sample_probability += sample_probability_correction;
            }

            /*
             * weighted rejection: keep scanning when the draw exceeds the
             * acceptance probability
             */
            if (candidate_probability >= sample_probability) {
                prev_candidate_cell = current_candidate_cell;
                continue;
            }

            /*
             * this candidate becomes a centroid
             */
            current_centroid_pgarray = centroids[current_centroid_idx].coordinates;

            /*
             * copy the candidate's coordinates into the official centroid slot
             */
            if (unlikely((current_centroid_pgarray != nullptr) and (current_candidate_pgarray != nullptr))) {
                auto point_coordinates_to = reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid_pgarray));
                auto point_coordinates_from = reinterpret_cast<double *>(ARR_DATA_PTR(current_candidate_pgarray));
                memset_s(point_coordinates_to, size_centroid_bytes, 0, size_centroid_bytes);
                errno_t rc = memcpy_s(point_coordinates_to,
                    size_centroid_bytes, point_coordinates_from, size_centroid_bytes);
                securec_check(rc, "\0", "\0");
            }

            /*
             * delete the chosen element from the candidate list.
             * NOTE(review): only pg_coordinates is pfree'd here — the GSPoint
             * struct itself is not (the cleanup loop below frees both);
             * presumably the memory context reclaims it, confirm.
             */
            centroids_candidates = list_delete_cell(centroids_candidates, current_candidate_cell, prev_candidate_cell);
            pfree(current_candidate->pg_coordinates);
            current_candidate_cell =
                prev_candidate_cell ? prev_candidate_cell : centroids_candidates ? centroids_candidates->head : nullptr;

            /*
             * sum_distances can be reset since it is fully recomputed below
             */
            sum_distances = 0.;

            /*
             * update every remaining candidate's distance-to-closest-centroid
             * w.r.t. the newly chosen centroid
             */
            tmp_cell = centroids_candidates ? centroids_candidates->head : nullptr;
            for (; tmp_cell != nullptr; tmp_cell = lnext(tmp_cell)) {
                tmp_candidate = reinterpret_cast<GSPoint *>(lfirst(tmp_cell));
                distance = l2_squared(reinterpret_cast<double *>(ARR_DATA_PTR(tmp_candidate->pg_coordinates)),
                    reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid_pgarray)), dimension);

                /*
                 * medium temporal locality, prefetch for write
                 */
                prefetch(lnext(tmp_cell), 1, 2);

                /*
                 * points are weighted, so the overall sum of distances must
                 * account for each point's weight (compensated multiply)
                 */
                twoMult(distance, tmp_candidate->weight, &distance, &sum_distances_correction);
                distance += sum_distances_correction;

                /*
                 * store the weighted distance at the point to avoid the
                 * multiplication later when sampling; update only if the new
                 * centroid is now the closest (always on the first centroid)
                 */
                if ((current_centroid_idx == 0) || (distance < tmp_candidate->distance_to_closest_centroid))
                    tmp_candidate->distance_to_closest_centroid = distance;

                /*
                 * each distance effectively appears as many times as the
                 * point's weight (compensated sum)
                 */
                twoSum(sum_distances, tmp_candidate->distance_to_closest_centroid, &sum_distances,
                    &sum_distances_local_correction);
                sum_distances_correction += sum_distances_local_correction;
            }
            /*
             * high temporal locality, prefetch for read
             */
            if (likely(current_candidate_cell != nullptr))
                prefetch(lnext(current_candidate_cell), 0, 3);

            sum_distances += sum_distances_correction;

            ++current_centroid_idx;
            tries_until_next_centroid = 0;
            break;
        }
    }

    /*
     * free the remaining candidate list (cells, coordinate arrays, and the
     * GSPoint structs themselves)
     */
    while (centroids_candidates) {
        current_candidate_cell = centroids_candidates->head;
        current_candidate = reinterpret_cast<GSPoint *>(lfirst(current_candidate_cell));
        /*
         * low temporal locality, prefetch for read
         */
        prefetch(lnext(current_candidate_cell), 0, 1);
        /*
         * this frees the list cell
         */
        centroids_candidates = list_delete_cell(centroids_candidates, current_candidate_cell, nullptr);
        /*
         * this frees the PG array inside the GSPoint
         */
        pfree(current_candidate->pg_coordinates);
        /*
         * and finally, this frees the GSPoint itself
         */
        pfree(current_candidate);
    }

    description->current_centroid = current_centroid_idx;

    return centroids_candidates;
}
|
||||
|
||||
/*
|
||||
* this aggregates the value of the function over all centroids
|
||||
*/
|
||||
void compute_cost(KMeansStateDescription *description, uint32_t const idx_current_centroids)
|
||||
{
|
||||
Centroid *centroids = description->centroids[idx_current_centroids];
|
||||
uint32_t const num_centroids = description->num_centroids;
|
||||
IncrementalStatistics *output_statistics = description->solution_statistics + idx_current_centroids;
|
||||
output_statistics->reset();
|
||||
|
||||
for (uint32_t c = 0; c < num_centroids; ++c)
|
||||
*output_statistics += centroids[c].statistics;
|
||||
}
|
||||
|
||||
/*
|
||||
* every iteration there is a set of centroids (the next one) that is reset
|
||||
*/
|
||||
void reset_centroids(KMeansStateDescription *description, uint32_t const idx_centroids,
|
||||
uint32_t const size_centroid_bytes)
|
||||
{
|
||||
uint32_t const num_centroids = description->num_centroids;
|
||||
Centroid *centroids = description->centroids[idx_centroids];
|
||||
Centroid *centroid = nullptr;
|
||||
double *centroid_coordinates = nullptr;
|
||||
|
||||
for (uint32_t c = 0; c < num_centroids; ++c) {
|
||||
centroid = centroids + c;
|
||||
centroid->statistics.reset();
|
||||
centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(centroid->coordinates));
|
||||
memset_s(centroid_coordinates, size_centroid_bytes, 0, size_centroid_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* given the running mean of a centroid and a new point, this adds the new point to the aggregate
|
||||
* using a sum that provides higher precision (we could provide much higher precision at the cost
|
||||
* of allocating yet another array to keep correction terms for every dimension
|
||||
*/
|
||||
force_inline void aggregate_point(double *centroid_aggregation, double const * new_point, uint32_t const dimension)
|
||||
{
|
||||
double local_correction = 0;
|
||||
for (uint32_t d = 0; d < dimension; ++d) {
|
||||
twoSum(centroid_aggregation[d], new_point[d], centroid_aggregation + d, &local_correction);
|
||||
centroid_aggregation[d] += local_correction;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* this produces the centroid by dividing the aggregate by the amount of points it got assigned
|
||||
* we assumed that population > 0
|
||||
*/
|
||||
force_inline void finish_centroid(double *centroid_aggregation, uint32_t const dimension, double const population)
|
||||
{
|
||||
double local_correction = 0.;
|
||||
for (uint32_t d = 0; d < dimension; ++d) {
|
||||
twoDiv(centroid_aggregation[d], population, centroid_aggregation + d, &local_correction);
|
||||
centroid_aggregation[d] += local_correction;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* updates the minimum bounding box to contain the new given point
|
||||
*/
|
||||
force_inline void update_bbox(double * const bbox_min, double * const bbox_max, double const * point,
|
||||
uint32_t const dimension)
|
||||
{
|
||||
uint32_t current_dimension = 0;
|
||||
double min = 0.;
|
||||
double max = 0.;
|
||||
double p = 0.;
|
||||
|
||||
while (current_dimension < dimension) {
|
||||
min = bbox_min[current_dimension];
|
||||
max = bbox_max[current_dimension];
|
||||
p = point[current_dimension];
|
||||
|
||||
/*
|
||||
* we do cmovs instead of ifs (we could spare a comparison from time to time
|
||||
* by using ifs, but we increase branch misprediction as well. thus we settle
|
||||
* for the branchless option (more expensive than a hit, but cheaper than a miss)
|
||||
*/
|
||||
bbox_min[current_dimension] = p < min ? p : min;
|
||||
bbox_max[current_dimension] = p > max ? p : max;
|
||||
|
||||
++current_dimension;
|
||||
}
|
||||
}
|
||||
|
||||
/*
 * First pass over a batch of (already validated) points: computes the minimum
 * bounding box of the data and counts the good points seen so far.
 * On the very first run the bounding box is seeded with the first point's
 * coordinates; afterwards every point merely stretches the box.
 *
 * Returns false iff the batch is empty, true otherwise.
 */
bool init_kmeans(KMeansStateDescription *description, double *bbox_min, double *bbox_max, GSPoint const * batch,
    uint32_t const num_slots, uint32_t const size_centroid_bytes)
{
    uint32_t const dimension = description->dimension;

    if (unlikely(num_slots == 0))
        return false;

    /*
     * the very first slot of the batch is a bit special. we got here only after
     * the point passed validity checks, so its data can be accessed directly
     */
    double const * coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(batch[0].pg_coordinates));

    if (unlikely(description->current_iteration == 0)) {
        /*
         * very first run: both corners of the bounding box start out as the first
         * point (no memset needed, the buffers were freshly allocated with palloc0)
         */
        errno_t rc = memcpy_s(bbox_min, size_centroid_bytes, coordinates, size_centroid_bytes);
        securec_check(rc, "\0", "\0");
        rc = memcpy_s(bbox_max, size_centroid_bytes, coordinates, size_centroid_bytes);
        securec_check(rc, "\0", "\0");
        description->current_iteration = 1;
    } else {
        update_bbox(bbox_min, bbox_max, coordinates, dimension);
    }

    ++description->num_good_points;

    /*
     * the remaining slots of the batch simply stretch the bounding box
     */
    for (uint32_t slot = 1; slot < num_slots; ++slot) {
        coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(batch[slot].pg_coordinates));
        update_bbox(bbox_min, bbox_max, coordinates, dimension);
        ++description->num_good_points;
    }

    /*
     * done with the batch
     */
    return true;
}
|
||||
|
||||
/*
|
||||
* we assume that all slots in the batch are non-null (guaranteed by the upper call)
|
||||
* also, that the next set of centroids has been reset previous to the very first call
|
||||
*/
|
||||
void update_centroids(KMeansStateDescription *description, GSPoint *slots, uint32_t const num_slots,
|
||||
uint32_t const idx_current_centroids, uint32_t const idx_next_centroids)
|
||||
{
|
||||
uint32_t const dimension = description->dimension;
|
||||
uint32_t current_slot = 0U;
|
||||
uint32_t current_centroid = 0U;
|
||||
uint32_t const num_centroids = description->num_centroids;
|
||||
uint32_t closest_centroid = 0U;
|
||||
GSPoint const * current_point = nullptr;
|
||||
double const * current_point_coordinates = nullptr;
|
||||
double *current_centroid_coordinates = nullptr;
|
||||
double *next_centroid_coordinates = nullptr;
|
||||
double dist = 0.;
|
||||
auto min_dist = DBL_MAX;
|
||||
Centroid *current_centroids = description->centroids[idx_current_centroids];
|
||||
Centroid *next_centroids = description->centroids[idx_next_centroids];
|
||||
Centroid *closest_centroid_ptr = nullptr;
|
||||
Centroid *closest_centroid_next_ptr = nullptr;
|
||||
bool min_dist_change = false;
|
||||
IncrementalStatistics local_statistics;
|
||||
|
||||
/*
|
||||
* just in case, but this should not happen as we control the parent call
|
||||
*/
|
||||
if (unlikely(num_slots == 0))
|
||||
return;
|
||||
|
||||
do {
|
||||
current_centroid = 0U;
|
||||
min_dist = DBL_MAX;
|
||||
/*
|
||||
* we obtain the coordinates of the current point
|
||||
*/
|
||||
current_point = slots + current_slot;
|
||||
|
||||
/*
|
||||
* this loops obtains the distance of the current point to all centroids and keeps the closest one
|
||||
*/
|
||||
while (likely(current_centroid < num_centroids)) {
|
||||
current_centroid_coordinates =
|
||||
reinterpret_cast<double *>(ARR_DATA_PTR(current_centroids[current_centroid].coordinates));
|
||||
current_point_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(current_point->pg_coordinates));
|
||||
|
||||
dist = description->distance(current_point_coordinates, current_centroid_coordinates, dimension);
|
||||
min_dist_change = dist < min_dist;
|
||||
closest_centroid = min_dist_change ? current_centroid : closest_centroid;
|
||||
min_dist = min_dist_change ? dist : min_dist;
|
||||
++current_centroid;
|
||||
}
|
||||
|
||||
/*
|
||||
* once the closest centroid has been detected we proceed with the aggregation and update
|
||||
* of statistics
|
||||
*/
|
||||
local_statistics.setTotal(min_dist);
|
||||
local_statistics.setMin(min_dist);
|
||||
local_statistics.setMax(min_dist);
|
||||
local_statistics.setPopulation(1ULL);
|
||||
closest_centroid_ptr = current_centroids + closest_centroid;
|
||||
closest_centroid_next_ptr = next_centroids + closest_centroid;
|
||||
closest_centroid_ptr->statistics += local_statistics;
|
||||
/*
|
||||
* for the next iteration (if there is any) we have to obtain a new centroid, which will be
|
||||
* the average of the points that we aggregate here
|
||||
*/
|
||||
next_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(closest_centroid_next_ptr->coordinates));
|
||||
aggregate_point(next_centroid_coordinates, current_point_coordinates, dimension);
|
||||
} while (likely(++current_slot < num_slots));
|
||||
}
|
||||
|
||||
/*
 * Produces the set of centroids for the next iteration: every aggregated
 * centroid is finished by dividing it by its cluster population. Centroids
 * whose cluster is empty are copied verbatim from the previous iteration so
 * they are not lost.
 *
 * size_centroid_bytes is the size in bytes of one centroid coordinates array.
 *
 * Fix: rc was declared as error_t; the secure functions used throughout this
 * file (see init_kmeans) return errno_t, so the declaration is made consistent.
 */
void merge_centroids(KMeansStateDescription *description, uint32_t const idx_current_centroids,
    uint32_t const idx_next_centroids, uint32_t const size_centroid_bytes)
{
    uint32_t const num_centroids = description->num_centroids;
    uint32_t const dimension = description->dimension;
    Centroid * const current_centroids = description->centroids[idx_current_centroids];
    Centroid * const next_centroids = description->centroids[idx_next_centroids];
    Centroid *current_centroid = nullptr;
    Centroid *next_centroid = nullptr;
    double *current_centroid_coordinates = nullptr;
    double *next_centroid_coordinates = nullptr;

    for (uint32_t c = 0; c < num_centroids; ++c) {
        next_centroid = next_centroids + c;
        current_centroid = current_centroids + c;
        next_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(next_centroid->coordinates));
        /*
         * if the cluster of a centroid is empty we copy it from the previous
         * iteration verbatim to not lose it
         */
        if (unlikely(current_centroid->statistics.getPopulation() == 0)) {
            current_centroid_coordinates = reinterpret_cast<double *>(ARR_DATA_PTR(current_centroid->coordinates));
            /*
             * observe that we do not have to memset to zero because this centroid
             * was reset before the run and, since no point was assigned to it, it
             * has remained reset
             */
            errno_t rc = memcpy_s(next_centroid_coordinates, size_centroid_bytes, current_centroid_coordinates,
                size_centroid_bytes);
            securec_check(rc, "\0", "\0");
        } else {
            finish_centroid(next_centroid_coordinates, dimension, current_centroid->statistics.getPopulation());
        }
    }
}
|
||||
|
||||
/*
 * Final hook of the k-means training pipeline. There is no state to tear
 * down, so it simply reports success.
 */
bool finish_kmeans()
{
    return true;
}
|
||||
|
||||
/*
 * Wraps a trained k-means model as an opaque predictor handle.
 * The model is not copied: the handle is merely a type-erased pointer to it.
 */
ModelPredictor kmeans_predict_prepare(Model const * model)
{
    Model *mutable_model = const_cast<Model *>(model);
    return reinterpret_cast<ModelPredictor>(mutable_model);
}
|
||||
|
||||
/*
 * Predicts the cluster of a single input point: returns (as an int32 Datum)
 * the id of the closest centroid of the trained model, or -1 if the input
 * array fails verification against the model's dimension.
 *
 * model - predictor handle produced by kmeans_predict_prepare
 * data/nulls/types - exactly one argument: a non-null float8 array of coordinates
 * nargs - number of arguments (must be 1)
 *
 * Raises an ERROR on invalid argument count/type/nullness, on a model with no
 * centroids, or on an unknown distance function id.
 */
Datum kmeans_predict(ModelPredictor model, Datum *data, bool *nulls, Oid *types, int32_t nargs)
{
    auto kmeans_model = reinterpret_cast<ModelKMeans *>(model);

    /*
     * sanity checks
     */
    if (unlikely(nargs != 1))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("k-means predict: only a single attribute containing the coordinates is accepted")));

    if (unlikely(types[0] != FLOAT8ARRAYOID))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("k-means predict: only double precision array of coordinates is accepted")));

    if (unlikely(nulls[0]))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("k-means predict: array of coordinates cannot be null")));

    uint32_t const num_centroids = kmeans_model->actual_num_centroids;
    uint32_t const dimension = kmeans_model->dimension;
    /* distance function selected below from the id stored in the model */
    double (*distance)(double const *, double const *, uint32_t const) = nullptr;

    if (unlikely(num_centroids == 0))
        ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
            errmsg("k-means predict: number of centroids must be positive")));

    /* map the model's persisted distance-function id onto the implementation */
    switch (kmeans_model->distance_function_id) {
        case KMEANS_L1:
            distance = l1;
            break;
        case KMEANS_L2:
            distance = l2;
            break;
        case KMEANS_L2_SQUARED:
            distance = l2_squared;
            break;
        case KMEANS_LINF:
            distance = linf;
            break;
        default:
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
                errmsg("k-means predict: distance function id %u not recognized", kmeans_model->distance_function_id)));
    }

    WHCentroid *current_centroid = nullptr;
    WHCentroid *closest_centroid = nullptr;
    auto input_point_pg_array = DatumGetArrayTypeP(data[0]);
    auto min_distance = DBL_MAX;
    double local_distance = 0.;
    double const * input_point_coordinates = nullptr;
    /* -1 is the sentinel returned for inputs that fail verification */
    int32_t closest_centroid_id = -1;
    bool const valid_input = verify_pgarray(input_point_pg_array, dimension);
    bool min_distance_changed = false;

    if (unlikely(!valid_input))
        return Int32GetDatum(closest_centroid_id);

    input_point_coordinates = reinterpret_cast<double const *>(ARR_DATA_PTR(input_point_pg_array));

    /* linear scan over all centroids, tracking the closest one branchlessly */
    for (uint32_t c = 0; c < num_centroids; ++c) {
        current_centroid = kmeans_model->centroids + c;
        local_distance = distance(input_point_coordinates, current_centroid->coordinates, dimension);
        min_distance_changed = local_distance < min_distance;
        min_distance = min_distance_changed ? local_distance : min_distance;
        closest_centroid = min_distance_changed ? current_centroid : closest_centroid;
    }

    /*
     * NOTE(review): if every computed distance compared false against DBL_MAX
     * (e.g. NaN results), closest_centroid would remain nullptr and the next
     * line would dereference it — confirm the distance functions cannot
     * return NaN for verified inputs.
     */
    closest_centroid_id = closest_centroid->id;

    /*
     * for the time being there is no other way to get to the computed distance other than by a log
     */
    return Int32GetDatum(closest_centroid_id);
}
|
|
@ -0,0 +1,318 @@
|
|||
# *DB4AI Snapshots* for relational dataset versioning
|
||||
|
||||
## 1. Introduction
|
||||
The module *DB4AI Snapshots* provides a robust and efficient framework that allows any database user to manage multiple versions of relational datasets through a convenient API.
|
||||
|
||||
### 1.1 Compact storage with efficient access
|
||||
Its main benefits are automated management of dependencies among large, versioned datasets and their compact storage. Therefore, *DB4AI Snapshots* leverages redundancies among different versions of data for providing a high level of compaction while minimizing adverse effects on performance. At the same time, snapshot data remains efficiently accessible using the full expressiveness of standard SQL.
|
||||
|
||||
### 1.2 Immutable snapshot data
|
||||
The snapshot API prevents the user from changing versioned data. Similar to a code repository, every change of data will generate a new version, i.e. a new snapshot.
|
||||
|
||||
### 1.3 Performance
|
||||
In addition to compact data representation, the primary goal of *DB4AI Snapshots* is high read performance, i.e. when used in highly repetitive and concurrent read operations for serving as training data sets for concurrent training of multiple ML models.
|
||||
|
||||
As a secondary performance goal, *DB4AI Snapshots* provides efficient data manipulation for creating and manipulating huge volumes of versioned relational data.
|
||||
|
||||
### 1.4 Documentation and Automation
|
||||
In addition, *DB4AI Snapshots* maintains a full documentation of the origin of any dataset, providing lineage and provenance for data, e.g. to be used in the context of reproducible training of ML models. In addition, *DB4AI Snapshots* facilitates automation, i.e. when applying repetitive transformation steps in data cleansing, or for automatically updating an existing training dataset with new data.
|
||||
|
||||
## 2. Quick start
|
||||
### 2.1 Setup
|
||||
|
||||
*DB4AI Snapshots* is automatically installed in every new database instance of openGauss.
|
||||
Therefore the CREATE DATABASE procedure creates the database schema *db4ai* within the new database and populates it with objects required for managing snapshot data.
|
||||
|
||||
After successful database creation, any user may start exploring the snapshot functionality. No additional privileges are required.
|
||||
|
||||
### 2.2 Examples
|
||||
|
||||
Set snapshot mode to compact storage model *CSS* (**C**omputed **S**nap**S**hot mode).
|
||||
|
||||
SET db4ai_snapshot_mode = CSS;
|
||||
|
||||
Create a snapshot 'm0@1.0.0' from existing data stored in table 'test_data'. The SQL statement may use arbitrary joins and mappings for defining the snapshot.
|
||||
|
||||
CREATE SNAPSHOT m0 AS
|
||||
SELECT a1, a3, a6, a8, a2, pk, a1 b9, a7, a5 FROM test_data;
|
||||
|
||||
Create snapshot 'm0@2.0.0' from snapshot 'm0@1.0.0' by applying arbitrary DML and DDL statements.
|
||||
The new version number indicates a snapshot schema revision by means of at least one ALTER statement.
|
||||
|
||||
CREATE SNAPSHOT m0 FROM @1.0.0 USING (
|
||||
UPDATE SNAPSHOT SET a1 = 5 WHERE pk % 10 = 0;
|
||||
ALTER SNAPSHOT ADD " u+ " INTEGER, ADD "x <> y"INT DEFAULT 2, ADD t CHAR(10) DEFAULT '',
|
||||
DROP a2, DROP COLUMN IF EXISTS b9, DROP COLUMN IF EXISTS b10;
|
||||
UPDATE SNAPSHOT SET "x <> y" = 8 WHERE pk < 100;
|
||||
ALTER SNAPSHOT DROP " u+ ", DROP IF EXISTS " x+ ";
|
||||
DELETE FROM SNAPSHOT WHERE pk = 3
|
||||
);
|
||||
|
||||
Create snapshot 'm0@2.0.1' from snapshot 'm0@2.0.0' by UPDATE while using a reference to another table. The new version number indicates an update operation (minor data patch). This example uses an AS clause for introducing 'i' as the snapshot's custom correlation name for joining with tables during the UPDATE operation.
|
||||
|
||||
CREATE SNAPSHOT m0 FROM @2.0.0 USING (
|
||||
UPDATE SNAPSHOT AS i SET a5 = o.a2 FROM test_data o
|
||||
WHERE i.pk = o.pk AND o.a3 % 8 = 0
|
||||
);
|
||||
|
||||
Create snapshot 'm0@2.1.0' from snapshot 'm0@2.0.1' by DELETE while using a reference to another table. The new version number indicates a data revision. This example uses the snapshot's default correlation name 'SNAPSHOT' for joining with another table.
|
||||
|
||||
CREATE SNAPSHOT m0 FROM @2.0.1 USING (
|
||||
DELETE FROM SNAPSHOT USING test_data o
|
||||
WHERE SNAPSHOT.pk = o.pk AND o.A7 % 2 = 0
|
||||
);
|
||||
|
||||
Create snapshot 'm0@2.2.0' from snapshot 'm0@2.1.0' by inserting new data. The new version number indicates another data revision.
|
||||
|
||||
CREATE SNAPSHOT m0 FROM @2.1.0 USING (
|
||||
INSERT INTO SNAPSHOT SELECT a1, a3, a6, a8, a2, pk+1000 pk, a7, a5, a4
|
||||
FROM test_data WHERE pk % 10 = 4
|
||||
);
|
||||
|
||||
The SQL syntax was extended with the new @ operator in relation names, allowing the user to specify a snapshot with version throughout SQL. Internally, snapshots are stored as views, where the actual name is generated according to GUC parameters on arbitrary level, e.g. on database level using the current setting of the version delimiter:
|
||||
|
||||
-- DEFAULT db4ai_snapshot_version_delimiter IS '@'
|
||||
ALTER DATABASE <name> SET db4ai_snapshot_version_delimiter = '#';
|
||||
|
||||
Similarly the version separator can be changed by the user:
|
||||
|
||||
    -- DEFAULT db4ai_snapshot_version_separator IS '.'
|
||||
    ALTER DATABASE <name> SET db4ai_snapshot_version_separator = '_';
|
||||
|
||||
Independently from the GUC parameter settings mentioned above, any snapshot version can be accessed:
|
||||
|
||||
-- standard version string @schema.revision.patch:
|
||||
SELECT * FROM public.data@1.2.3;
|
||||
-- user-defined version strings:
|
||||
SELECT * FROM accounting.invoice@2021;
|
||||
SELECT * FROM user.cars@cleaned;
|
||||
-- quoted identifier for blanks, keywords, special characters, etc.:
|
||||
SELECT * FROM user.cars@"rev 1.1";
|
||||
-- or string literal:
|
||||
SELECT * FROM user.cars@'rev 1.1';
|
||||
|
||||
Alternatively, using the internal name (depends on GUC settings):
|
||||
|
||||
-- With internal view name, using default GUC settings
|
||||
SELECT * FROM public."data@1.2.3";
|
||||
|
||||
-- With internal view name, using custom GUC settings, as above
|
||||
SELECT * FROM public.data#1_2_3;
|
||||
|
||||
## 3. Privileges
|
||||
All members of role **PUBLIC** may use **DB4AI Snapshots**.
|
||||
|
||||
## 4. Supported Systems and SQL compatibility modes
|
||||
The current version of *DB4AI Snapshots* is tested in openGauss SQL compatibility modes A, B and C.
|
||||
|
||||
## 5. Portability
|
||||
*DB4AI Snapshots* uses standard SQL for implementing its functionality.
|
||||
|
||||
## 6. Dependencies
|
||||
None.
|
||||
|
||||
## 7. Reference & Documentation
|
||||
|
||||
### 7.1 Configuration parameters
|
||||
|
||||
**DB4AI Snapshots** exposes several configuration parameters, via the system's global unified configuration (GUC) management.
|
||||
Configuration parameters may be set on the scope of functions (CREATE FUNCTION), transactions (SET LOCAL), sessions (SET),
|
||||
user (ALTER USER), database (ALTER DATABASE), or on system-wide scope (postgresql.conf).
|
||||
|
||||
SET [SESSION | LOCAL] configuration_parameter { TO | = } { value | 'value' }
|
||||
CREATE FUNCTION <..> SET configuration_parameter { TO | = } { value | 'value' }
|
||||
ALTER DATABASE name SET configuration_parameter { TO | = } { value | 'value' }
|
||||
ALTER USER name [ IN DATABASE database_name ] SET configuration_parameter { TO | = } { value | 'value' }
|
||||
|
||||
The following snapshot configuration parameters are currently supported:
|
||||
|
||||
#### db4ai_snapshot_mode = { MSS | CSS }
|
||||
This snapshot configuration parameter allows to switch between materialized snapshot (*MSS*) mode, where every new snapshot is created as compressed but fully
|
||||
materialized copy of its parent's data, or computed snapshot (*CSS*) mode. In *CSS* mode, the system attempts to exploit redundancies among dependent snapshot versions
|
||||
for minimizing storage requirements.
|
||||
|
||||
The setting of *db4ai_snapshot_mode* may be adjusted at any time and it will have effect on subsequent snapshot operations within the scope of the new setting.
|
||||
|
||||
Whenever *db4ai_snapshot_mode* is not set in the current scope, it defaults to *MSS*.
|
||||
|
||||
##### Example
|
||||
SET db4ai_snapshot_mode = CSS;
|
||||
|
||||
#### db4ai_snapshot_version_delimiter = value
|
||||
|
||||
This snapshot configuration parameter controls the character that delimits the *snapshot version* postfix within snapshot names.
|
||||
In consequence, the character used as *db4ai_snapshot_version_delimiter* cannot be used in snapshot names, neither in the snapshot
|
||||
name prefix, nor in the snapshot version postfix. Also, the setting of *db4ai_snapshot_version_delimiter* must be distinct from
|
||||
*db4ai_snapshot_version_separator*. Whenever *db4ai_snapshot_version_delimiter* is not set in the current scope, it defaults to
|
||||
the symbol '@' (At-sign).
|
||||
|
||||
*Note:* Snapshots created with different settings of *db4ai_snapshot_version_delimiter* are not compatible among each
|
||||
other. Hence, it is advisable to ensure the setting is stable, i.e. by setting it permanently, e.g. on database scope.
|
||||
|
||||
##### Example
|
||||
ALTER DATABASE name SET db4ai_snapshot_version_delimiter = '#';
|
||||
|
||||
#### db4ai_snapshot_version_separator = value
|
||||
|
||||
This snapshot configuration parameter controls the character that separates the *snapshot version* within snapshot names. In consequence,
|
||||
*db4ai_snapshot_version_separator* must not be set to any character representing a digit [0-9].
|
||||
Also, the setting of *db4ai_snapshot_version_separator* must be distinct from *db4ai_snapshot_version_delimiter*.
|
||||
Whenever *db4ai_snapshot_version_separator* is not set in the current scope, it defaults to punctuation mark '.' (period).
|
||||
|
||||
*Note:* Snapshots created with different settings of *db4ai_snapshot_version_separator* do not support automatic version number
|
||||
generation among each other. Hence, it is advisable to ensure the setting is stable, i.e. by setting it permanently, e.g. on database scope.
|
||||
|
||||
##### Example
|
||||
ALTER DATABASE name SET db4ai_snapshot_version_separator = '_';
|
||||
|
||||
### 7.2 Accessing a snapshot
|
||||
|
||||
Independently from the GUC parameter settings mentioned above, any snapshot version can be accessed:
|
||||
|
||||
SELECT … FROM […,]
|
||||
<snapshot_qualified_name> @ <vconst | ident | sconst>
|
||||
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
|
||||
|
||||
Alternatively, using a standard version string as internal name (depends on GUC settings):
|
||||
|
||||
SELECT … FROM […,] <snapshot_qualified_name> INTEGER
|
||||
<db4ai_snapshot_version_delimiter> INTEGER
|
||||
<db4ai_snapshot_version_delimiter> INTEGER
|
||||
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
|
||||
|
||||
Alternatively, using a user-defined version string as internal name (depends on GUC settings):
|
||||
|
||||
SELECT … FROM […,]
|
||||
<snapshot_qualified_name> <db4ai_snapshot_version_delimiter> <snapshot_version_string>
|
||||
[WHERE …] [GROUP BY …] [HAVING …] [ORDER BY …];
|
||||
|
||||
If any component of <snapshot_qualified_name>, <db4ai_snapshot_version_delimiter>, <db4ai_snapshot_version_separator>, or <snapshot_version_string> should contain special characters, then quoting of the snapshot name is required.
|
||||
|
||||
### 7.3 Creating a snapshot
|
||||
|
||||
CREATE SNAPSHOT <qualified_name> [@ <version | ident | sconst>]
|
||||
        [COMMENT IS <sconst>]
|
||||
AS <SelectStmt>;
|
||||
|
||||
The CREATE SNAPSHOT AS statement is invoked for creating a new snapshot. A caller provides the qualified_name for the snapshot to be created. The \<SelectStmt\> defines the content of the new snapshot in SQL. The optional @ operator allows to assign a custom version number or string to the new snapshot. The default version number is @ 1.0.0.
|
||||
A snapshot may be annotated using the optional COMMENT IS clause.
|
||||
|
||||
**Example:**
|
||||
|
||||
CREATE SNAPSHOT public.cars AS
|
||||
SELECT id, make, price, modified FROM cars_table;
|
||||
|
||||
The CREATE SNAPSHOT AS statement will create the snapshot 'public.cars@1.0.0' by selecting some columns of all the tuples in relation cars_table, which exists in the operational data store.
|
||||
The created snapshot's name 'public.cars' is automatically extended with the suffix '@1.0.0' to the full snapshot name 'public.cars@1.0.0', thereby creating a unique, versioned identifier for the snapshot.
|
||||
|
||||
The DB4AI module of openGauss stores metadata associated with snapshots in a DB4AI catalog table *db4ai.snapshot*. The catalog exposes various metadata about the snapshot, particularly noteworthy is the field 'snapshot_definition' that provides documentation how the snapshot was generated. The DB4AI catalog serves for managing the life cycle of snapshots and allows exploring available snapshots in the system.
|
||||
|
||||
In summary, an invocation of the CREATE SNAPSHOT AS statement will create a corresponding entry in the DB4AI catalog, with a unique snapshot name and documentation of the snapshot's lineage. The new snapshot is in state 'published'. Initial snapshots serve as true and reusable copy of operational data, and as a starting point for subsequent data curation, therefore initial snapshots are already immutable. In addition, the system creates a view with the published snapshot's name, with grantable read-only privileges for the current user. The current user may access the snapshot, using arbitrary SQL statements against this view, or grant read-access privileges to other user for sharing the new snapshot for collaboration. Published snapshots may be used for model training, by using the new snapshot name as input parameter to the *db4ai.train* function of the DB4AI model warehouse. Other users may discover new snapshots by browsing the DB4AI catalog, and if corresponding read access privileges on the snapshot view are granted by the snapshot's creator, collaborative model training using this snapshot as training data can commence.
|
||||
|
||||
### 7.4 Creating a snapshot revision
|
||||
|
||||
CREATE SNAPSHOT <qualified_name> [@ <version | ident | sconst>]
|
||||
FROM @ <version | ident | sconst>
|
||||
        [COMMENT IS <sconst>]
|
||||
USING (
|
||||
{ INSERT [INTO SNAPSHOT] …
|
||||
| UPDATE [SNAPSHOT] [AS <alias>] SET … [FROM …] [WHERE …]
|
||||
| DELETE [FROM SNAPSHOT] [AS <alias>] [USING …] [WHERE …]
|
||||
| ALTER [SNAPSHOT] { ADD … | DROP … } [, …]
|
||||
} [; …]
|
||||
);
|
||||
|
||||
The CREATE SNAPSHOT FROM statement serves for creating a modified and immutable snapshot based on an existing snapshot. The parent snapshot is specified by the qualified_name and the version provided in the FROM clause.
|
||||
The new snapshot is created within the parent's schema and it also inherits the prefix of the parent's name, but without the parent's version number. The statements listed in the USING clause define how the parent snapshot shall be modified by means of a batch of SQL DDL and DML statements, i.e. ALTER, INSERT, UPDATE, and DELETE.
|
||||
|
||||
**Examples:**
|
||||
|
||||
CREATE SNAPSHOT public.cars FROM @1.0.0 USING (
|
||||
ALTER ADD year int, DROP make;
|
||||
INSERT SELECT * FROM cars_table WHERE modified=CURRENT_DATE;
|
||||
UPDATE SET year=in.year FROM cars_table in WHERE SNAPSHOT.id=in.id;
|
||||
DELETE WHERE modified<CURRENT_DATE-30
|
||||
); -- Example with 'short SQL' notation
|
||||
|
||||
CREATE SNAPSHOT public.cars FROM @1.0.0 USING (
|
||||
ALTER SNAPSHOT ADD COLUMN year int, DROP COLUMN make;
|
||||
INSERT INTO SNAPSHOT SELECT * FROM cars_table WHERE modified=CURRENT_DATE;
|
||||
        UPDATE SNAPSHOT SET year=in.year FROM cars_table in WHERE SNAPSHOT.id=in.id;
|
||||
DELETE FROM SNAPSHOT WHERE modified<CURRENT_DATE-30
|
||||
    ); -- Example with standard SQL syntax
|
||||
|
||||
In this example, the new snapshot version starts with the current state of snapshot 'public.cars@1.0.0' and adds a new column 'year' to the 'cars' snapshot, while dropping column 'make' that has become irrelevant for the user. This first example uses the short SQL notation, where the individual statements are provided by the user without explicitly stating the snapshot's correlation name. In addition to this syntax, openGauss also accepts standard SQL statements (second example) which tend to be slightly more verbose. Note that both variants allow the introduction of custom correlation names in UPDATE FROM and DELETE USING statements with the AS clause, e.g. `UPDATE AS s [...] WHERE s.id=in.id);` or `UPDATE SNAPSHOT AS s [...] WHERE s.id=in.id);` in the example above.
|
||||
|
||||
The INSERT operation shows an example for pulling fresh data from the operational data store into the new snapshot. The UPDATE operation exemplifies populating the newly added column 'year' with data coming from the operational data store, and finally the DELETE operation demonstrates how to remove obsolete data from a snapshot. The name of the resulting snapshot of this invocation is 'public.cars@2.0.0'. Similar to CREATE SNAPSHOT AS, the user may override the default version numbering scheme that generates version number '@2.0.0'. The optional COMMENT IS clause allows the user to associate a descriptive textual 'comment' with the unit of work corresponding to this invocation of CREATE SNAPSHOT FROM, for improving collaboration and documentation, as well as for change tracking purposes.
|
||||
|
||||
Since all snapshots are immutable by definition, CREATE SNAPSHOT FROM creates a separate snapshot 'public.cars@2.0.0', initially as a logical copy of the parent snapshot 'public.cars@1.0.0', and applies changes from the USING clause, which corresponds to an integral unit of work in data curation. Similar to a SQL script, the batch of operations is executed atomically and consecutively on the logical copy of the parent snapshot. The parent snapshot remains immutable and is completely unaffected by these changes.
|
||||
|
||||
Additionally, the system automatically records all applied changes in the DB4AI snapshot catalog, such that the catalog contains an accurate, complete, and consistent documentation of all changes applied to any snapshot. The catalog also stores the origin and lineage of the created snapshot as a reference to its parent, making provenance of data fully traceable and, as demonstrated later, serves as a repository of data curation operations for supporting automation in snapshot generation.
|
||||
|
||||
The operations themselves allow data scientists to remove columns from the parent snapshot, but also to add and populate new ones, e.g. for the purpose of data annotation. By INSERT, rows may be freely added, e.g. from the operational data source or from other snapshots. Inaccurate or irrelevant data can be deleted, as part of the data cleansing process, regardless of whether the data comes from an immutable parent snapshot or directly from the operational data store. Finally, UPDATE statements allow correction of inaccurate or corrupt data, serve for data imputation of missing data and allow normalization of numeric values to a common scale.
|
||||
|
||||
In summary, the CREATE SNAPSHOT FROM statement was designed for supporting the full spectrum of recurring tasks in data curation:
|
||||
• Data cleansing: Remove or correct irrelevant, inaccurate, or corrupt data
|
||||
• Data Imputation: Fill missing data
|
||||
• Labeling & Annotation: add immutable columns with computed values
|
||||
• Data normalization: Update existing columns to a common scale
|
||||
• Permutation: Support reordering of data for iterative model training
|
||||
• Indexing: Support random access for model training
|
||||
|
||||
Invoking CREATE SNAPSHOT FROM statement allows multiple users to collaborate concurrently in the process of data curation, where each user may break data curation tasks into a set of CREATE SNAPSHOT FROM operations, to be executed in atomic batches. This form of collaboration is similar to software engineers collaborating on a common code repository, but here the concept is extended to include code and data. One invocation of CREATE SNAPSHOT FROM corresponds to a commit operation in a git repository.
|
||||
|
||||
In summary, an invocation of the CREATE SNAPSHOT FROM statement will create a corresponding entry in the DB4AI catalog, with a unique snapshot name and documentation of the snapshot's lineage. The new snapshot remains in state 'unpublished', potentially awaiting further data curation. In addition, the system creates a view with the created snapshot's name, with grantable read-only privileges for the current user. Concurrent calls to CREATE SNAPSHOT FROM are permissive and result in separate new versions originating from the same parent (branches). The current user may access the snapshot using arbitrary, read-only SQL statements against this view, or grant read-access privileges to other users, for sharing the created snapshot and enabling collaboration in data curation. Unpublished snapshots may not participate in model training. Yet, other users may discover unpublished snapshots by browsing the DB4AI catalog, and if corresponding read access privileges on the snapshot view are granted by the snapshot's creator, collaborative data curation using this snapshot can commence.
|
||||
|
||||
### 7.5 Sampling snapshots
|
||||
|
||||
SAMPLE SNAPSHOT <qualified_name> @ <version | ident | sconst>
|
||||
[STRATIFY BY attr_list]
|
||||
{ AS <label> AT RATIO <num> [COMMENT IS <comment>] } [, …]
|
||||
|
||||
The SAMPLE SNAPSHOT statement is used to sample data from a given snapshot (original snapshot) into one or more descendant, but independent snapshots (branches), satisfying a condition given under the parameter 'ratio'.
|
||||
|
||||
**Example:**
|
||||
|
||||
SAMPLE SNAPSHOT public.cars@2.0.0
|
||||
STRATIFY BY color
|
||||
AS _train AT RATIO .8,
|
||||
AS _test AT RATIO .2;
|
||||
|
||||
This invocation of SAMPLE SNAPSHOT creates two snapshots from the snapshot 'cars@2.0.0', one designated for ML model training purposes: 'cars_train@2.0.0' and the other for ML model testing: 'cars_test@2.0.0'. Note that descendant snapshots inherit the parent's schema, name prefix and version suffix, while each sample definition provides a name infix for making descendant snapshot names unique. The AT RATIO clause specifies the ratio of tuples qualifying for the resulting snapshots, namely 80% for training and 20% for testing. The STRATIFY BY clause specifies that the fraction of records for each car color (white, black, red…) is the same in all three participating snapshots.
|
||||
|
||||
### 7.6 Publishing snapshots
|
||||
|
||||
PUBLISH SNAPSHOT <qualified_name> @ <version | ident | sconst>;
|
||||
|
||||
Whenever a snapshot is created with the CREATE SNAPSHOT FROM statement, it is initially unavailable for ML model training. Such snapshots allow users to collaboratively apply further changes in manageable units of work, for facilitating cooperation in data curation. A snapshot is finalized by publishing it via the PUBLISH SNAPSHOT statement. Published snapshots may be used for model training, by using the new snapshot name as input parameter to the *db4ai.train* function of the DB4AI model warehouse. Other users may discover new snapshots by browsing the DB4AI catalog, and if corresponding read access privileges on the snapshot view are granted by the snapshot's creator, collaborative model training using this snapshot as training data can commence.
|
||||
|
||||
**Example:**
|
||||
PUBLISH SNAPSHOT public.cars@2.0.0;
|
||||
|
||||
Above is an exemplary invocation, publishing snapshot 'public.cars@2.0.0'.
|
||||
|
||||
### 7.7 Archiving snapshots
|
||||
|
||||
ARCHIVE SNAPSHOT <qualified_name> @ <version | ident | sconst>;
|
||||
|
||||
Archiving changes the state of any snapshot to 'archived', while the snapshot remains immutable and cannot participate in CREATE SNAPSHOT FROM or *db4ai.train* operations. Archived snapshots may be purged, permanently deleting their data and recovering occupied storage space, or they may be reactivated by invoking PUBLISH SNAPSHOT on the archived snapshot.
|
||||
|
||||
**Example:**
|
||||
|
||||
ARCHIVE SNAPSHOT public.cars@2.0.0;
|
||||
|
||||
The example above archives snapshot 'public.cars@2.0.0' that was previously in state 'published' or 'unpublished'.
|
||||
|
||||
### 7.8 Purging snapshots
|
||||
|
||||
PURGE SNAPSHOT <qualified_name> @ <version | ident | sconst>;
|
||||
|
||||
The PURGE SNAPSHOT statement is used to permanently delete all data associated with a snapshot from the system. A prerequisite to purging is that the snapshot is not referenced by any existing trained model in the DB4AI model warehouse. Snapshots still referenced by trained models cannot be purged.
|
||||
|
||||
Purging snapshots without existing descendant snapshots, removes them completely and occupied storage space is recovered. If descendant snapshots exist, the purged snapshot will be merged into adjacent snapshots, such that no information on lineage is lost, but storage efficiency is improved. In any case, the purged snapshot's name becomes invalid and is removed from the system.
|
||||
|
||||
**Example:**
|
||||
|
||||
PURGE SNAPSHOT public.cars@2.0.0;
|
||||
|
||||
The example above recovers storage space occupied by 'public.cars@2.0.0' by removing the snapshot completely.
|
|
@ -0,0 +1,475 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* create.sql
|
||||
* Create DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/create.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
-- Privileged second stage of snapshot creation.
--
-- Runs as SECURITY DEFINER: it materializes the snapshot's backing table and
-- internal view inside the db4ai schema, grants the creating user read access,
-- and records the snapshot in the db4ai.snapshot catalog. It must therefore
-- never be callable directly; the stack check below rejects any invocation
-- that did not come through the public wrapper db4ai.create_snapshot.
--
-- Parameters:
--   s_id       - snapshot id (names backing objects db4ai.t<s_id>/db4ai.v<s_id>)
--   i_schema   - snapshot namespace
--   i_name     - snapshot name (already suffixed with the version string)
--   i_commands - [1..3] user's SELECT/FROM/DISTRIBUTE BY, [4] projection with
--                internal column aliases (f1, f2, ...), [5] reverse projection
--                mapping internal columns back to user names
--   i_comment  - snapshot description, appended to the view's COMMENT
--   i_owner    - snapshot owner; receives grantable SELECT on the backing view
CREATE OR REPLACE FUNCTION db4ai.create_snapshot_internal(
    IN s_id BIGINT,                  -- snapshot id
    IN i_schema NAME,                -- snapshot namespace
    IN i_name NAME,                  -- snapshot name
    IN i_commands TEXT[],            -- commands defining snapshot data and layout
    IN i_comment TEXT,               -- snapshot description
    IN i_owner NAME                  -- snapshot owner
)
RETURNS VOID LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
    e_stack_act TEXT;                -- current call stack, captured for validation
    dist_cmd TEXT;                   -- DISTRIBUTE BY translation for backing table
    row_count BIGINT;                -- number of rows in this snapshot
BEGIN

    -- Security gate: raise and immediately catch an exception purely to obtain
    -- the current PL/pgSQL call stack, then verify this function was invoked
    -- from the expected PERFORM in db4ai.create_snapshot. NOTE: the LIKE
    -- pattern pins the caller's exact source line ("line 269 at PERFORM"), so
    -- any edit that shifts lines inside create_snapshot breaks this check.
    BEGIN
        RAISE EXCEPTION 'SECURITY_STACK_CHECK';
    EXCEPTION WHEN OTHERS THEN
        GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;

        -- When invoked with db4ai as the current schema, the stack prints the
        -- function name unqualified; normalize it to the qualified spelling.
        IF CURRENT_SCHEMA = 'db4ai' THEN
            e_stack_act := replace(e_stack_act, 'ion cre', 'ion db4ai.cre');
        END IF;

        IF e_stack_act NOT LIKE E'referenced column: create_snapshot_internal\n'
            'SQL statement "SELECT db4ai.create_snapshot_internal(s_id, i_schema, i_name, i_commands, i_comment, CURRENT_USER)"\n'
            'PL/pgSQL function db4ai.create_snapshot(name,name,text[],name,text) line 269 at PERFORM%'
        THEN
            RAISE EXCEPTION 'direct call to db4ai.create_snapshot_internal(bigint,name,name,text[],text,name) is not allowed'
            USING HINT = 'call public interface db4ai.create_snapshot instead';
        END IF;
    END;

    -- Translate the user's DISTRIBUTE BY clause (if any) so that it references
    -- the internal backing-column names (f1, f2, ...) instead of user columns.
    -- NOTE(review): dist_cmd is unconditionally reset to '' after this block
    -- (see below), so the translation currently only validates the clause; its
    -- result is discarded.
    IF length(i_commands[3]) > 0 THEN
        <<translate_dist_by_hash>>
        DECLARE
            pattern TEXT;             -- current user column name being assembled
            mapping NAME[];           -- pairs: internal backing column, user column name

            quoted BOOLEAN := FALSE;  -- inside a quoted identifier
            cur_ch VARCHAR;           -- current character in tokenizer
            idx INTEGER := 0;         -- loop counter, cannot use FOR .. iterator
            tokens TEXT;              -- user's column name list in DISTRIBUTE BY HASH()
        BEGIN

            -- extract mapping from projection list for view definition;
            -- each match yields [fN, unquoted name, quoted name] with embedded
            -- doubled quotes collapsed
            mapping := array(SELECT unnest(ARRAY[ m[1], coalesce(m[2], replace(m[3],'""','"'))]) FROM regexp_matches(
                i_commands[5], 't[0-9]+\.(f[0-9]+) AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")', 'g') m);

            -- extract field list from DISTRIBUTE BY clause
            tokens :=(regexp_matches(i_commands[3], '^\s*DISTRIBUTE\s+BY\s+HASH\s*\((.*)\)\s*$', 'i'))[1];
            IF tokens IS NULL OR tokens SIMILAR TO '\s*' THEN
                tokens := (regexp_matches(i_commands[3], '^\s*DISTRIBUTE\s+BY\s+REPLICATION\s*$', 'i'))[1];
                IF tokens IS NULL OR tokens SIMILAR TO '\s*' THEN
                    RAISE EXCEPTION 'cannot match DISTRIBUTE BY clause'
                    USING HINT = 'currently only DISTRIBUTE BY REPLICATION and DISTRIBUTE BY HASH(column_name [, ...]) supported';
                END IF;
                -- REPLICATION references no columns: no translation required, bail out
                dist_cmd := ' ' || i_commands[3];
                EXIT translate_dist_by_hash;
            END IF;
            tokens := tokens || ' ';  -- trailing blank flushes the last token

            -- prepare the translated command
            dist_cmd = ' DISTRIBUTE BY HASH(';

            -- BEGIN tokenizer code for testing

            pattern := '';

            -- Character-by-character scan of the HASH(...) column list,
            -- honoring double-quoted identifiers and doubled-quote escapes.
            LOOP
                idx := idx + 1;
                cur_ch := substr(tokens, idx, 1);
                EXIT WHEN cur_ch IS NULL OR cur_ch = '';

                CASE cur_ch
                WHEN '"' THEN
                    IF quoted AND substr(tokens, idx + 1, 1) = '"' THEN
                        pattern := pattern || '"';   -- "" inside quotes is a literal quote
                        idx := idx + 1;
                    ELSE
                        quoted := NOT quoted;
                    END IF;
                    IF quoted THEN
                        CONTINUE;
                    END IF;
                WHEN ',' THEN
                    IF quoted THEN
                        pattern := pattern || cur_ch;
                        CONTINUE;
                    ELSIF pattern IS NULL OR length(pattern) = 0 THEN
                        pattern := ',';              -- empty token: surfaces as unmappable
                    ELSE
                        idx := idx - 1; -- reset on comma for next loop
                    END IF;
                WHEN ' ', E'\n', E'\t' THEN
                    IF quoted THEN
                        pattern := pattern || cur_ch;
                        CONTINUE;
                    ELSIF pattern IS NULL OR length(pattern) = 0 THEN
                        CONTINUE;                    -- skip whitespace between tokens
                    END IF;
                ELSE
                    -- unquoted identifiers are case-folded, quoted kept verbatim
                    pattern := pattern || CASE WHEN quoted THEN cur_ch ELSE lower(cur_ch) END;
                    CONTINUE;
                END CASE;

                -- END tokenizer code for testing

                -- attempt to map the completed token; mapping holds
                -- (backing column, user name) pairs at odd/even indexes
                FOR idx IN 2 .. array_length(mapping, 1) BY 2 LOOP
                    IF pattern = mapping[idx] THEN
                        -- apply the mapping
                        dist_cmd := dist_cmd || mapping[idx-1] || ',';
                        pattern := NULL;             -- mark token as consumed
                        EXIT;
                    END IF;
                END LOOP;

                -- check if pattern was mapped
                IF pattern IS NOT NULL THEN
                    RAISE EXCEPTION 'unable to map field "%" to backing table', pattern;
                END IF;

            END LOOP;

            -- still inside a quoted identifier after end of input: malformed clause
            IF quoted THEN
                RAISE EXCEPTION 'unterminated quoted identifier ''%'' at or near: ''%''',
                    substr(pattern, 1, char_length(pattern)-1), i_commands[3];
            END IF;

            dist_cmd := rtrim(dist_cmd, ',') || ')';
        END;
    END IF;

    dist_cmd := ''; -- we silently drop DISTRIBUTE_BY

    -- Materialize the snapshot data into a column-oriented backing table,
    -- populated from the caller's temporary staging table.
    EXECUTE 'CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)' || dist_cmd
        || ' AS SELECT ' || i_commands[4] || ' FROM _db4ai_tmp_x' || s_id;
    EXECUTE 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || quote_ident(i_schema)
        || '.' || quote_ident(i_name) || '''';
    -- Internal view restores the user's column names and exposes row identity
    -- (xc_node_id, ctid) needed by later snapshot operations.
    EXECUTE 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || i_commands[5] || ', xc_node_id, ctid FROM db4ai.t' || s_id;
    EXECUTE 'COMMENT ON VIEW db4ai.v' || s_id || ' IS ''snapshot ' || quote_ident(i_schema) || '.' || quote_ident(i_name)
        || ' backed by db4ai.t' || s_id || CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
    -- NOTE(review): i_owner is interpolated without quote_ident; it is supplied
    -- as CURRENT_USER by the wrapper, but quoting would be safer — confirm.
    EXECUTE 'GRANT SELECT ON db4ai.v' || s_id || ' TO ' || i_owner || ' WITH GRANT OPTION';
    EXECUTE 'SELECT COUNT(*) FROM db4ai.v' || s_id INTO STRICT row_count;

    -- store only original commands supplied by user
    i_commands := ARRAY[i_commands[1], i_commands[2], i_commands[3]];
    -- Register the snapshot; a freshly created snapshot is its own root, and
    -- 'published' is set TRUE here (root snapshots are immediately published).
    INSERT INTO db4ai.snapshot (id, root_id, schema, name, owner, commands, comment, published, row_count)
    VALUES (s_id, s_id, i_schema, i_name, i_owner, i_commands, i_comment, TRUE, row_count);

END;
$$;
|
||||
|
||||
-- Public entry point for creating a new DB4AI snapshot.
--
-- Runs as SECURITY INVOKER: the user's SELECT/FROM commands are executed with
-- the caller's own privileges into a temporary staging table, and only the
-- final materialization is delegated to the privileged helper
-- db4ai.create_snapshot_internal.
--
-- Parameters:
--   i_schema   - snapshot namespace; defaults to CURRENT_USER's schema if it
--                exists, otherwise 'public'
--   i_name     - snapshot name (version suffix is appended automatically)
--   i_commands - TEXT[] of commands: a SELECT clause, optionally a separate
--                FROM clause, optionally a DISTRIBUTE BY clause
--   i_vers     - optional version postfix; defaults to <delimiter>1.0.0
--   i_comment  - optional snapshot description
-- Returns: composite db4ai.snapshot_name (schema, name) of the new snapshot.
--
-- WARNING: db4ai.create_snapshot_internal validates its caller by matching the
-- stack text 'line 269 at PERFORM' against THIS function's body. Inserting or
-- removing lines inside the $$ body shifts that line number and breaks
-- snapshot creation; comments below are therefore kept on existing lines.
CREATE OR REPLACE FUNCTION db4ai.create_snapshot(
    IN i_schema NAME,              -- snapshot namespace, default is CURRENT_USER or PUBLIC
    IN i_name NAME,                -- snapshot name
    IN i_commands TEXT[],          -- commands defining snapshot data and layout
    IN i_vers NAME DEFAULT NULL,   -- override version postfix
    IN i_comment TEXT DEFAULT NULL -- snapshot description
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
    s_id BIGINT;                   -- snapshot id
    s_mode VARCHAR(3);             -- current snapshot mode (MSS or CSS)
    s_vers_del CHAR;               -- snapshot version delimiter, default '@'
    s_vers_sep CHAR;               -- snapshot version separator, default '.'
    separation_of_powers TEXT;     -- current separation-of-rights state
    qual_name TEXT;                -- qualified snapshot name
    command_str TEXT;              -- command string for iteration
    pattern TEXT;                  -- command pattern for matching
    proj_cmd TEXT;                 -- SELECT clause for create user view command (may be NULL if from_cmd is not NULL)
    from_cmd TEXT;                 -- FROM clause for command (may be NULL if proj_cmd is not NULL)
    dist_cmd TEXT;                 -- DISTRIBUTE BY clause for command (may be NULL)
    res db4ai.snapshot_name;       -- composite result
BEGIN

    -- obtain active message level; best-effort, GUC may be absent
    BEGIN
        EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
        RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
    EXCEPTION WHEN OTHERS THEN -- deliberately ignored: keep function-level ERROR setting
    END;

    -- obtain database state of separation of rights
    BEGIN
        separation_of_powers := upper(current_setting('enableSeparationOfDuty'));
    EXCEPTION WHEN OTHERS THEN -- GUC not defined: treat as disabled
        separation_of_powers := 'OFF';
    END;

    IF separation_of_powers NOT IN ('ON', 'OFF') THEN
        RAISE EXCEPTION 'Uncertain state of separation of rights.';
    ELSIF separation_of_powers = 'ON' THEN
        RAISE EXCEPTION 'Snapshot is not supported in separation of rights';
    END IF;

    -- obtain active snapshot mode
    BEGIN
        s_mode := upper(current_setting('db4ai_snapshot_mode'));
    EXCEPTION WHEN OTHERS THEN -- GUC not defined: default to materialized mode
        s_mode := 'MSS';
    END;

    IF s_mode NOT IN ('CSS', 'MSS') THEN
        RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
    END IF;

    -- obtain relevant configuration parameters (defaults: '@' and '.')
    BEGIN
        s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
    EXCEPTION WHEN OTHERS THEN
        s_vers_del := '@';
    END;
    BEGIN
        s_vers_sep := current_setting('db4ai_snapshot_version_separator');
    EXCEPTION WHEN OTHERS THEN
        s_vers_sep := '.';
    END;

    -- check all input parameters
    IF i_schema IS NULL OR i_schema = '' THEN -- fall back to user schema, else public
        i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
    END IF;

    IF i_name IS NULL OR i_name = '' THEN
        RAISE EXCEPTION 'i_name cannot be NULL or empty';
    ELSIF strpos(i_name, s_vers_del) > 0 THEN -- delimiter is reserved for the version suffix
        RAISE EXCEPTION 'i_name must not contain ''%'' characters', s_vers_del;
    END IF;

    -- PG BUG: array_ndims('{}') or array_dims(ARRAY[]::INT[]) returns NULL
    IF i_commands IS NULL OR array_length(i_commands, 1) IS NULL OR array_length(i_commands, 2) IS NOT NULL THEN
        RAISE EXCEPTION 'i_commands array malformed'
        USING HINT = 'pass SQL commands as TEXT[] literal, e.g. ''{SELECT *, FROM public.t, DISTRIBUTE BY HASH(id)''';
    END IF;

    FOREACH command_str IN ARRAY i_commands LOOP
        IF command_str IS NULL THEN
            RAISE EXCEPTION 'i_commands array contains NULL values';
        END IF;
    END LOOP;

    -- classify each command by its leading keyword(s) into SELECT / FROM /
    -- DISTRIBUTE BY; a combined SELECT command is split into its parts
    FOREACH command_str IN ARRAY i_commands LOOP
        command_str := btrim(command_str);
        pattern := upper(regexp_replace(left(command_str, 30), '\s+', ' ', 'g'));
        IF left(pattern, 7) = 'SELECT ' THEN
            IF proj_cmd IS NULL THEN
                proj_cmd := command_str;

                DECLARE
                    nested INT := 0;          -- parenthesis nesting level
                    quoted BOOLEAN := FALSE;  -- inside quoted identifier
                    cur_ch VARCHAR;           -- current character in tokenizer
                    idx INTEGER := 0;         -- loop counter, cannot use FOR .. iterator
                    start INTEGER := 1;       -- position where current token begins
                    stmt TEXT := command_str; -- remaining statement text being split
                BEGIN

                    -- BEGIN splitter code for testing

                    pattern := '';

                    -- scan for top-level FROM / DISTRIBUTE BY keywords, ignoring
                    -- matches inside quotes or parentheses
                    LOOP
                        idx := idx + 1;
                        cur_ch := substr(stmt, idx, 1);
                        EXIT WHEN cur_ch IS NULL OR cur_ch = '';

                        CASE cur_ch
                        WHEN '"' THEN
                            IF quoted AND substr(stmt, idx + 1, 1) = '"' THEN
                                idx := idx + 1; -- skip escaped quote
                            ELSE
                                quoted := NOT quoted;
                            END IF;
                            IF quoted THEN
                                CONTINUE;
                            END IF;
                        WHEN '(' THEN
                            nested := nested + 1;
                            CONTINUE;
                        WHEN ')' THEN
                            nested := nested - 1;
                            IF nested < 0 THEN
                                RAISE EXCEPTION 'syntax error at or near '')'' in ''%'' at position ''%''', stmt, idx;
                            END IF;
                            CONTINUE;
                        WHEN ' ' THEN
                            IF quoted OR nested > 0 THEN
                                CONTINUE;
                            ELSIF pattern IS NULL OR length(pattern) = 0 THEN
                                start := idx; -- token starts after this blank
                                CONTINUE;
                            END IF;
                        WHEN ';' THEN
                            RAISE EXCEPTION 'syntax error at or near '';'' in ''%'' at position ''%''', stmt, idx;
                            CONTINUE;
                        ELSE
                            pattern := pattern || upper(cur_ch);
                            CONTINUE;
                        END CASE;

                        -- END splitter code for testing

                        IF pattern = 'FROM' THEN
                            from_cmd := substr(stmt, start + 1);
                            proj_cmd := left(stmt, start - 1);
                            stmt := from_cmd; -- keep scanning the FROM part for DISTRIBUTE BY
                            nested := 0;
                            quoted := FALSE;
                            idx := idx - start;
                            start := 1;
                            RAISE NOTICE E'SELECT SPLITTING1\n%\n%\n%', stmt, proj_cmd, from_cmd;
                        ELSIF pattern = 'DISTRIBUTE' THEN
                            RAISE NOTICE E'SELECT SPLITTING2\n%\n%\n%', stmt, proj_cmd, from_cmd;
                            CONTINUE; -- wait for the following BY to form DISTRIBUTEBY
                        ELSIF pattern = 'DISTRIBUTEBY' THEN
                            dist_cmd := substr(stmt, start + 1);
                            from_cmd := left(stmt, start - 1);
                            RAISE NOTICE E'SELECT SPLITTING3\n%\n%\n%\n%', stmt, proj_cmd, from_cmd, dist_cmd;
                            EXIT;
                        END IF;
                        pattern := '';
                        start := idx;
                    END LOOP;
                END;
            ELSE
                RAISE EXCEPTION 'multiple SELECT clauses in i_commands: ''%'' ''%''', proj_cmd, command_str;
            END IF;
        ELSIF left(pattern, 5) = 'FROM ' THEN
            IF from_cmd IS NULL THEN
                from_cmd := command_str;
            ELSE
                RAISE EXCEPTION 'multiple FROM clauses in i_commands: ''%'' ''%''', from_cmd, command_str;
            END IF;
        ELSIF left(pattern, 14) = 'DISTRIBUTE BY ' THEN
            IF dist_cmd IS NULL THEN
                dist_cmd := command_str;
            ELSE
                RAISE EXCEPTION 'multiple DISTRIBUTE BY clauses in i_commands: ''%'' ''%''', dist_cmd, command_str;
            END IF;
        ELSE
            RAISE EXCEPTION 'unrecognized command in i_commands: ''%''', command_str;
        END IF;
    END LOOP;

    IF proj_cmd IS NULL THEN
        -- minimum required input
        IF from_cmd IS NULL THEN
            RAISE EXCEPTION 'SELECT and FROM clauses are missing in i_commands';
        END IF;
        -- supply default projection
        proj_cmd := 'SELECT *';
    ELSE
        IF from_cmd IS NULL AND strpos(upper(proj_cmd), 'FROM ') = 0 THEN
            RAISE EXCEPTION 'FROM clause is missing in i_commands';
        END IF;
    END IF;

    IF dist_cmd IS NULL THEN
        dist_cmd := '';
    END IF;

    -- normalize the version suffix: default 1.0.0, enforce single leading delimiter
    IF i_vers IS NULL OR i_vers = '' THEN
        i_vers := s_vers_del || '1' || s_vers_sep || '0' || s_vers_sep || '0';
    ELSE
        i_vers := replace(i_vers, chr(2), s_vers_sep); -- chr(2) is an internal placeholder for the separator
        IF LEFT(i_vers, 1) <> s_vers_del THEN
            i_vers := s_vers_del || i_vers;
        ELSIF char_length(i_vers ) < 2 THEN -- bare delimiter with no version text
            RAISE EXCEPTION 'illegal i_vers: ''%''', s_vers_del;
        END IF;
        IF strpos(substr(i_vers, 2), s_vers_del) > 0 THEN
            RAISE EXCEPTION 'i_vers may contain only one single, leading ''%'' character', s_vers_del
            USING HINT = 'specify snapshot version as [' || s_vers_del || ']x' || s_vers_sep || 'y' || s_vers_sep || 'z or ['
                || s_vers_del || ']label with optional, leading ''' || s_vers_del || '''';
        END IF;
    END IF;

    IF char_length(i_name || i_vers) > 63 THEN -- NAMEDATALEN-1 identifier limit
        RAISE EXCEPTION 'snapshot name too long: ''%''', i_name || i_vers;
    ELSE
        i_name := i_name || i_vers;
    END IF;

    -- the final name of the snapshot
    qual_name := quote_ident(i_schema) || '.' || quote_ident(i_name);

    -- check for duplicate snapshot
    IF 0 < (SELECT COUNT(*) FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name) THEN
        RAISE EXCEPTION 'snapshot % already exists' , qual_name;
    END IF;

    --SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
    SELECT COALESCE(MAX(id)+1,0) FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb

    -- execute using current user privileges
    DECLARE
        e_message TEXT;               -- exception message
    BEGIN
        EXECUTE 'CREATE TEMPORARY TABLE _db4ai_tmp_x' || s_id || ' AS ' || proj_cmd
            || CASE WHEN from_cmd IS NULL THEN '' ELSE ' ' || from_cmd END;
    EXCEPTION WHEN undefined_table THEN
        GET STACKED DIAGNOSTICS e_message = MESSAGE_TEXT;

        -- during function invocation, search path is redirected to {pg_temp, pg_catalog, function_schema} and becomes immutable
        RAISE INFO 'could not resolve relation % using system-defined "search_path" setting during function invocation: ''%''',
            substr(e_message, 10, 1 + strpos(substr(e_message,11), '" does not exist')),
            array_to_string(current_schemas(TRUE),', ')
        USING HINT = 'snapshots require schema-qualified table references, e.g. schema_name.table_name';
        RAISE;
    END;

    -- extract normalized projection list: [4] maps user columns to internal
    -- names (f1, f2, ...), [5] maps them back for the snapshot view
    i_commands := ARRAY[proj_cmd, from_cmd, dist_cmd, '', ''];
    SELECT string_agg(ident, ', '),
           string_agg(ident || ' AS f' || ordinal_position, ', '),
           string_agg('t' || s_id || '.f' || ordinal_position || ' AS ' || ident, ', ')
    FROM ( SELECT ordinal_position, quote_ident(column_name) AS ident
           FROM information_schema.columns
           WHERE table_schema = (SELECT nspname FROM pg_namespace WHERE oid=pg_my_temp_schema())
           AND table_name = '_db4ai_tmp_x' || s_id
           ORDER BY ordinal_position
    ) INTO STRICT proj_cmd, i_commands[4], i_commands[5];
    IF proj_cmd IS NULL THEN -- staging table produced no columns: internal error
        RAISE EXCEPTION 'create snapshot internal error1: %', s_id;
    END IF;

    -- finalize the snapshot using elevated privileges
    PERFORM db4ai.create_snapshot_internal(s_id, i_schema, i_name, i_commands, i_comment, CURRENT_USER);

    -- drop temporary table used for privilege transfer
    EXECUTE 'DROP TABLE _db4ai_tmp_x' || s_id;

    -- create custom view, owned by current user
    EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT ' || proj_cmd || ' FROM db4ai.v' || s_id;
    EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
        || CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
    EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;

    -- return final snapshot name
    res := ROW(i_schema, i_name);
    -- PG BUG: PG 9.2 cannot return composite type, only a reference to a variable of composite type
    return res;

END;
$$;
|
||||
-- Register the catalog description for the public API function.
-- The argument list must match the declared signature: with empty parentheses,
-- COMMENT ON FUNCTION db4ai.create_snapshot() designates a zero-argument
-- overload that does not exist, so the statement fails at deploy time.
COMMENT ON FUNCTION db4ai.create_snapshot(NAME, NAME, TEXT[], NAME, TEXT) IS 'Create a new snapshot';
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* deploy.sql
|
||||
* Deploy DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/deploy.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
-- db local objects
|
||||
\set ON_ERROR_STOP 1
|
||||
|
||||
BEGIN;
|
||||
|
||||
\ir schema.sql
|
||||
\ir create.sql
|
||||
\ir prepare.sql
|
||||
\ir sample.sql
|
||||
\ir publish.sql
|
||||
\ir purge.sql
|
||||
|
||||
COMMIT;
|
|
@ -0,0 +1,913 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* prepare.sql
|
||||
* Prepare DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/prepare.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
-- Internal SECURITY DEFINER worker behind db4ai.prepare_snapshot / db4ai.sample_snapshot.
--
-- It has two modes, selected by i_mapping:
--   * i_mapping IS NOT NULL: generate (but do not run) the privileged commands that
--     create the snapshot's user view, its INSERT/UPDATE/DELETE rewrite rules and the
--     matching grants, appending them to i_exec_cmds, then return immediately.
--   * i_mapping IS NULL: execute the queued commands in i_exec_cmds starting at i_idx
--     (entries tagged 'O' run here with definer rights; an entry tagged 'U' returns
--     control to the caller so it can run that command as the invoking user), then
--     finalize the snapshot: drop the temporary rewrite rules, comment the view,
--     reduce the owner's privileges to SELECT ... WITH GRANT OPTION, count the rows
--     and register the snapshot in db4ai.snapshot.
--
-- Because this function runs with definer rights, it refuses direct calls: it captures
-- its own PL/pgSQL call stack and only proceeds when invoked from the known call sites
-- inside db4ai.prepare_snapshot / db4ai.sample_snapshot (matched against literal stack
-- text, including hard-coded line numbers — keep those in sync when editing the callers).
--
-- i_mapping layout: triples per user column, (private backing column, shared backing
-- column, user column name); either backing entry may be NULL, see the CSS/MSS
-- comments in the loop below.
CREATE OR REPLACE FUNCTION db4ai.prepare_snapshot_internal(
    IN s_id BIGINT,                  -- snapshot id
    IN p_id BIGINT,                  -- parent id
    IN m_id BIGINT,                  -- matrix id
    IN r_id BIGINT,                  -- root id
    IN i_schema NAME,                -- snapshot namespace
    IN i_name NAME,                  -- snapshot name
    IN i_commands TEXT[],            -- DDL and DML commands defining snapshot modifications
    IN i_comment TEXT,               -- snapshot description
    IN i_owner NAME,                 -- snapshot owner
    INOUT i_idx INT,                 -- index for exec_cmds
    INOUT i_exec_cmds TEXT[],        -- DDL and DML for execution
    IN i_mapping NAME[] DEFAULT NULL -- mapping of user columns to backing column; generate rules if not NULL
)
RETURNS RECORD LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
    command_str TEXT;                -- command string for iterator
    e_stack_act TEXT;                -- current stack for validation
    row_count BIGINT;                -- number of rows in this snapshot
BEGIN

    -- Security check: provoke an exception purely to capture the current PL/pgSQL
    -- call stack, then verify this function was reached via the sanctioned callers.
    BEGIN
        RAISE EXCEPTION 'SECURITY_STACK_CHECK';
    EXCEPTION WHEN OTHERS THEN
        GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;

        -- When called with CURRENT_SCHEMA = db4ai the stack omits the schema
        -- qualification; normalize so the patterns below match either way.
        IF CURRENT_SCHEMA = 'db4ai' THEN
            e_stack_act := replace(e_stack_act, ' prepare_snapshot(', ' db4ai.prepare_snapshot(');
            e_stack_act := replace(e_stack_act, ' prepare_snapshot_internal(', ' db4ai.prepare_snapshot_internal(');
            e_stack_act := replace(e_stack_act, ' sample_snapshot(', ' db4ai.sample_snapshot(');
        END IF;

        -- Strip the inner "referenced column" frame (fixed 199-char prefix) added
        -- when the caller extracts the .i_idx field of the composite result.
        IF e_stack_act LIKE E'referenced column: i_idx\n'
            'SQL statement "SELECT (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,\n'
            ' CURRENT_USER, idx, exec_cmds)).i_idx"\n%'
        THEN
            e_stack_act := substr(e_stack_act, 200);
        END IF;

        -- Only the known call sites (by exact source line number) may invoke us.
        IF e_stack_act NOT SIMILAR TO 'PL/pgSQL function db4ai.prepare_snapshot\(name,name,text\[\],name,text\) line (175|541|607|714) at assignment%'
        AND e_stack_act NOT LIKE 'PL/pgSQL function db4ai.sample_snapshot(name,name,name[],numeric[],name[],text[]) line 215 at IF%'
        THEN
            RAISE EXCEPTION 'direct call to db4ai.prepare_snapshot_internal(bigint,bigint,bigint,bigint,name,name,text[],text,name,'
                'int,text[],name[]) is not allowed'
                USING HINT = 'call public interface db4ai.prepare_snapshot instead';
        END IF;
    END;

    --generate rules from the mapping
    IF i_mapping IS NOT NULL THEN
        DECLARE
            sel_view TEXT := 'CREATE OR REPLACE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ';
            ins_grnt TEXT := 'GRANT INSERT (';   -- column list for INSERT grant
            ins_rule TEXT := 'CREATE OR REPLACE RULE _INSERT AS ON INSERT TO db4ai.v' || s_id || ' DO INSTEAD INSERT INTO '
                'db4ai.t' || coalesce(m_id, s_id) || '(';
            ins_vals TEXT := ' VALUES (';        -- VALUES list of the INSERT rule
            upd_grnt TEXT;                       -- column list for UPDATE grant (built lazily)
            upd_rule TEXT;                       -- UPDATE rewrite rule (built lazily)
            -- Distribution key columns of the backing table; tokenized from
            -- getdistributekey(), handling quoted identifiers with doubled quotes.
            dist_key NAME[] := array_agg(coalesce(m[1], replace(m[2], '""', '"'))) FROM regexp_matches(
                getdistributekey('db4ai.t' || coalesce(m_id, p_id)),'([^\s",]+)|"((?:[^"]*"")*[^"]*)"', 'g') m;
        BEGIN

            -- i_mapping holds triples: [idx-2]=private column, [idx-1]=shared column,
            -- [idx]=user-visible column name.
            FOR idx IN 3 .. array_length(i_mapping, 1) BY 3 LOOP
                IF idx = 3 THEN
                    ins_grnt := ins_grnt || quote_ident(i_mapping[idx]);
                    ins_rule := ins_rule || coalesce(i_mapping[idx-2], i_mapping[idx-1]);
                    ins_vals := ins_vals || 'new.' || quote_ident(i_mapping[idx]);
                ELSE
                    sel_view := sel_view || ', ';
                    ins_grnt := ins_grnt || ', ' || quote_ident(i_mapping[idx]);
                    ins_rule := ins_rule || ', ' || coalesce(i_mapping[idx-2], i_mapping[idx-1]);
                    ins_vals := ins_vals || ', ' || 'new.' || quote_ident(i_mapping[idx]);
                END IF;

                IF i_mapping[idx-2] IS NULL THEN -- handle shared columns without private (only CSS)
                    sel_view := sel_view || i_mapping[idx-1];
                ELSE
                    IF i_mapping[idx-1] IS NULL THEN -- handle sole private column (all MSS and added CSS columns)
                        sel_view := sel_view || i_mapping[idx-2];
                    ELSE -- handle shadowing (CSS CASE)
                        sel_view := sel_view || 'coalesce(' || i_mapping[idx-2] || ', ' || i_mapping[idx-1] || ')';
                    END IF;
                    IF dist_key IS NULL OR NOT i_mapping[idx-2] = ANY(dist_key) THEN -- no updates on DISTRIBUTE BY columns
                        upd_grnt := CASE WHEN upd_grnt IS NULL -- grant update only on private column
                            THEN 'GRANT UPDATE (' ELSE upd_grnt ||', ' END || quote_ident(i_mapping[idx]);
                        upd_rule := CASE WHEN upd_rule IS NULL -- update only private column
                            THEN 'CREATE OR REPLACE RULE _UPDATE AS ON UPDATE TO db4ai.v' || s_id || ' DO INSTEAD UPDATE db4ai.t'
                                || coalesce(m_id, s_id) || ' SET '
                            ELSE upd_rule || ', ' END
                            || i_mapping[idx-2] || '=new.' || quote_ident(i_mapping[idx]); -- update private column
                    END IF;
                END IF;
                sel_view := sel_view || ' AS ' || quote_ident(i_mapping[idx]);
            END LOOP;

            -- Queue ('O' = run as owner/definer) the view, grants and rewrite rules.
            -- xc_node_id/ctid are exposed so DELETE/UPDATE rules can address the
            -- exact physical row in the backing table.
            i_exec_cmds := i_exec_cmds || ARRAY [
                [ 'O', sel_view || ', xc_node_id, ctid FROM db4ai.t' || coalesce(m_id, s_id)
                    || CASE WHEN m_id IS NULL THEN '' ELSE ' WHERE _' || s_id END ],
                [ 'O', 'GRANT SELECT, DELETE ON db4ai.v' || s_id || ' TO ' || i_owner ],
                [ 'O', ins_grnt || ') ON db4ai.v' || s_id || ' TO ' || i_owner ],
                [ 'O', ins_rule || CASE WHEN m_id IS NULL THEN ')' ELSE ', _' || s_id || ')' END || ins_vals
                    || CASE WHEN m_id IS NULL THEN ')' ELSE ', TRUE)' END ],
                [ 'O', 'CREATE OR REPLACE RULE _DELETE AS ON DELETE TO db4ai.v' || s_id || ' DO INSTEAD '
                    || CASE WHEN m_id IS NULL THEN 'DELETE FROM db4ai.t' || s_id ELSE 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || '=FALSE' END
                    || ' WHERE t' || coalesce(m_id, s_id) || '.xc_node_id=old.xc_node_id AND t' || coalesce(m_id, s_id) || '.ctid=old.ctid' ] ];

            -- UPDATE grant/rule exist only when at least one updatable column was found.
            IF upd_rule IS NOT NULL THEN
                i_exec_cmds := i_exec_cmds || ARRAY [
                    [ 'O', upd_grnt || ') ON db4ai.v' || s_id || ' TO ' || i_owner ],
                    [ 'O', upd_rule || ' WHERE t' || coalesce(m_id, s_id) || '.xc_node_id=old.xc_node_id AND t' || coalesce(m_id, s_id) || '.ctid=old.ctid' ]];
            END IF;

            RETURN; -- generation mode never executes commands
        END;
    END IF;

    -- Execute the queries
    LOOP EXIT WHEN i_idx = 1 + array_length(i_exec_cmds, 1);
        CASE i_exec_cmds[i_idx][1]
        WHEN 'O' THEN
            -- RAISE NOTICE 'owner executing: %', i_exec_cmds[i_idx][2];
            EXECUTE i_exec_cmds[i_idx][2];
            i_idx := i_idx + 1;
        WHEN 'U' THEN
            -- hand control back to the caller, which runs this command as the user
            RETURN;
        ELSE -- this should never happen
            -- BUGFIX: was "idx" / "i_exec_cmds[idx]", but no variable "idx" is
            -- declared in this scope (the FOR variable above is local to the
            -- nested block), so the error report itself would fail to resolve.
            RAISE EXCEPTION 'prepare snapshot internal error2: % %', i_idx, i_exec_cmds[i_idx];
        END CASE;
    END LOOP;

    -- Finalize: the rewrite rules were only needed while applying i_commands.
    EXECUTE 'DROP RULE IF EXISTS _INSERT ON db4ai.v' || s_id;
    EXECUTE 'DROP RULE IF EXISTS _UPDATE ON db4ai.v' || s_id;
    EXECUTE 'DROP RULE IF EXISTS _DELETE ON db4ai.v' || s_id;
    EXECUTE 'COMMENT ON VIEW db4ai.v' || s_id || ' IS ''snapshot ' || quote_ident(i_schema) || '.' || quote_ident(i_name)
        || ' backed by db4ai.t' || coalesce(m_id, s_id) || CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment
        || '"' ELSE '' END || '''';
    -- Snapshots are immutable once prepared: strip DML privileges, leave read-only
    -- access the owner may pass on.
    EXECUTE 'REVOKE ALL PRIVILEGES ON db4ai.v' || s_id || ' FROM ' || i_owner;
    EXECUTE 'GRANT SELECT ON db4ai.v' || s_id || ' TO ' || i_owner || ' WITH GRANT OPTION';
    EXECUTE 'SELECT COUNT(*) FROM db4ai.v' || s_id INTO STRICT row_count;

    -- Register the new snapshot in the catalog.
    INSERT INTO db4ai.snapshot (id, parent_id, matrix_id, root_id, schema, name, owner, commands, comment, row_count)
    VALUES (s_id, p_id, m_id, r_id, i_schema, i_name, i_owner, i_commands, i_comment, row_count);

END;
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION db4ai.prepare_snapshot(
|
||||
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
|
||||
IN i_parent NAME, -- parent snapshot name
|
||||
IN i_commands TEXT[], -- DDL and DML commands defining snapshot modifications
|
||||
IN i_vers NAME DEFAULT NULL, -- override version postfix
|
||||
IN i_comment TEXT DEFAULT NULL -- description of this unit of data curation
|
||||
)
|
||||
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
|
||||
AS $$
|
||||
DECLARE
|
||||
s_id BIGINT; -- snapshot id
|
||||
r_id BIGINT; -- root id
|
||||
m_id BIGINT; -- matrix id
|
||||
p_id BIGINT; -- parent id
|
||||
c_id BIGINT; -- column id for backing table
|
||||
s_name NAME; -- current snapshot name
|
||||
s_mode VARCHAR(3); -- current snapshot mode
|
||||
s_vers_del CHAR; -- snapshot version delimiter, default '@'
|
||||
s_vers_sep CHAR; -- snapshot version separator, default '.'
|
||||
s_uv_proj TEXT; -- snapshot user view projection list
|
||||
p_name_vers TEXT[]; -- split full parent name into name and version
|
||||
p_sv_proj TEXT; -- parent snapshot system view projection list
|
||||
command_str TEXT; -- command string for iterator
|
||||
pattern TEXT; -- command pattern for matching
|
||||
ops_arr BOOLEAN[]; -- operation classes in i_commands
|
||||
ops_str TEXT[] := '{ALTER, INSERT, DELETE, UPDATE}'; -- operation classes as string
|
||||
ALTER_OP INT := 1; -- ALTER operation class
|
||||
INSERT_OP INT := 2; -- INSERT operation class
|
||||
DELETE_OP INT := 3; -- DELETE operation class
|
||||
UPDATE_OP INT := 4; -- UPDATE operation class
|
||||
vers_arr INT[]; -- split version digits
|
||||
exec_cmds TEXT[]; -- commands for execution
|
||||
qual_name TEXT; -- qualified snapshot name
|
||||
ALTER_CLAUSE INT := 1; -- ALTER clause class
|
||||
WHERE_CLAUSE INT := 2; -- WHERE clause class - for insert, update and delete
|
||||
FROM_CLAUSE INT := 3; -- FROM clause class - for insert, update and delete (as USING)
|
||||
SET_CLAUSE INT := 4; -- SET clause class - for updates and insert
|
||||
-- (as generic SQL: projection, list, VALUES, ...)
|
||||
AS_CLAUSE INT := 5; -- AS clause class - for delete and update
|
||||
-- (correlation name) - default is "snapshot"
|
||||
current_op INT; -- currently parsed operation class
|
||||
next_op INT; -- following operation class
|
||||
current_clauses TEXT[]; -- clauses for current operation
|
||||
next_clauses TEXT[]; -- clauses for next operation
|
||||
mapping NAME[]; -- mapping user column names to backing column names
|
||||
newmap BOOLEAN := FALSE; -- mapping has changed
|
||||
res db4ai.snapshot_name; -- composite result
|
||||
BEGIN
|
||||
|
||||
-- obtain active message level
|
||||
BEGIN
|
||||
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
|
||||
RAISE INFO 'effective client_min_messages is %', upper(current_setting('db4ai.message_level'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
END;
|
||||
|
||||
-- obtain active snapshot mode
|
||||
BEGIN
|
||||
s_mode := upper(current_setting('db4ai_snapshot_mode'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_mode := 'MSS';
|
||||
END;
|
||||
|
||||
IF s_mode NOT IN ('CSS', 'MSS') THEN
|
||||
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
|
||||
END IF;
|
||||
|
||||
-- obtain relevant configuration parameters
|
||||
BEGIN
|
||||
s_vers_del := upper(current_setting('db4ai_snapshot_version_delimiter'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_vers_del := '@';
|
||||
END;
|
||||
BEGIN
|
||||
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_vers_sep := '.';
|
||||
END;
|
||||
|
||||
-- check all input parameters
|
||||
IF i_schema IS NULL OR i_schema = '' THEN
|
||||
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
|
||||
END IF;
|
||||
|
||||
IF i_parent IS NULL OR i_parent = '' THEN
|
||||
RAISE EXCEPTION 'i_parent cannot be NULL or empty';
|
||||
ELSE
|
||||
i_parent := replace(i_parent, chr(1), s_vers_del);
|
||||
i_parent := replace(i_parent, chr(2), s_vers_sep);
|
||||
p_name_vers := regexp_split_to_array(i_parent, s_vers_del);
|
||||
IF array_length(p_name_vers, 1) <> 2 OR array_length(p_name_vers, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_parent must contain exactly one ''%'' character', s_vers_del
|
||||
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
-- check if parent exists
|
||||
BEGIN
|
||||
SELECT id, matrix_id, root_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_parent INTO STRICT p_id, m_id, r_id;
|
||||
EXCEPTION WHEN NO_DATA_FOUND THEN
|
||||
RAISE EXCEPTION 'parent snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_parent);
|
||||
END;
|
||||
|
||||
--SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
|
||||
SELECT MAX(id)+1 FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb
|
||||
|
||||
-- extract highest used c_id from existing backing table or parent ()
|
||||
-- cannot use information_schema here, because the current user has no read permission on the backing table
|
||||
SELECT 1 + max(ltrim(attname, 'f')::BIGINT) FROM pg_catalog.pg_attribute INTO STRICT c_id
|
||||
WHERE attrelid = ('db4ai.t' || coalesce(m_id, p_id))::regclass AND attnum > 0 AND NOT attisdropped AND attname like 'f%';
|
||||
|
||||
IF c_id IS NULL THEN
|
||||
RAISE EXCEPTION 'prepare snapshot internal error3: %', coalesce(m_id, p_id);
|
||||
END IF;
|
||||
|
||||
IF i_commands IS NULL OR array_length(i_commands, 1) IS NULL OR array_length(i_commands, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_commands array malformed'
|
||||
USING HINT = 'pass SQL DML and DDL operations as TEXT[] literal, e.g. ''{ALTER, ADD a int, DROP c, DELETE, '
|
||||
'WHERE b=5, INSERT, FROM t, UPDATE, FROM t, SET x=y, SET z=f(z), WHERE t.u=v}''';
|
||||
END IF;
|
||||
|
||||
-- extract normalized projection list
|
||||
p_sv_proj := substring(pg_get_viewdef('db4ai.v' || p_id), '^SELECT (.*), t[0-9]+\.xc_node_id, t[0-9]+\.ctid FROM.*$');
|
||||
mapping := array(SELECT unnest(ARRAY[ m[1], m[2], coalesce(m[3], replace(m[4],'""','"'))])
|
||||
FROM regexp_matches(p_sv_proj, CASE s_mode WHEN 'CSS'
|
||||
-- inherited CSS columns are shared (private: nullable, shared: not null, user_cname: not null)
|
||||
THEN '(?:COALESCE\(t[0-9]+\.(f[0-9]+), )?t[0-9]+\.(f[0-9]+)(?:\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")'
|
||||
-- all MSS columns are private (privte: not null, shared: nullable, user_cname: not null)
|
||||
ELSE '(?:COALESCE\()?t[0-9]+\.(f[0-9]+)(?:, t[0-9]+\.(f[0-9]+)\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")'
|
||||
END, 'g') m);
|
||||
|
||||
-- In principle two MSS naming conventions are possible:
|
||||
-- (a) plain column names for MSS, allowing direct operations, but only with CompactSQL, not with TrueSQL. Conversion to CSS
|
||||
-- then needs to rename user columns (if they are in f[0-9]+) or simply add columns max fXX + 1
|
||||
-- (b) translated columns names for MSS and CSS. Simple MSS->CSS conversion. No direct operations, always using rewrite.
|
||||
-- This is the more general approach!
|
||||
|
||||
-- The need for rewriting using rules:
|
||||
-- UPDATE SET AS T SET T.x=y, "T.x"=y, “_$%_\\’”=NULL, T.z=DEFAULT, (a, b, T.c) = (SELECT 1 a, 2 b, 3 c)
|
||||
-- FROM H AS I, J as K
|
||||
-- WHERE _x_=5 AND _T.z_=5 AND v='T.z' AND (SELECT x, T.z FROM A as T)
|
||||
-- unqualified, ambiguous, quoted, string literals, ... in SET clause maybe still manageable but in WHERE
|
||||
-- no guarantee for correctness possible -> need to use system's SQL parser with rewrite rules! */
|
||||
|
||||
-- create / upgrade + prepare target snapshots for SQL DML/DDL operations
|
||||
IF s_mode = 'MSS' THEN
|
||||
DECLARE
|
||||
s_bt_proj TEXT; -- snapshot backing table projection list
|
||||
s_bt_dist TEXT; -- DISTRIBUTE BY clause for creating backing table
|
||||
BEGIN
|
||||
|
||||
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
|
||||
s_bt_proj := s_bt_proj || quote_ident(mapping[idx]) || ' AS ' || mapping[idx-2] || ',';
|
||||
END LOOP;
|
||||
|
||||
s_bt_dist := getdistributekey('db4ai.t' || coalesce(m_id, p_id));
|
||||
s_bt_dist := CASE WHEN s_bt_dist IS NULL
|
||||
THEN ' DISTRIBUTE BY REPLICATION'
|
||||
ELSE ' DISTRIBUTE BY HASH(' || s_bt_dist || ')' END; s_bt_dist := ''; -- we silently drop DISTRIBUTE_BY
|
||||
|
||||
exec_cmds := ARRAY [
|
||||
[ 'O', 'CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)'
|
||||
-- extract and propagate DISTRIBUTE BY from parent
|
||||
|| s_bt_dist || ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id ]];
|
||||
END;
|
||||
ELSIF s_mode = 'CSS' THEN
|
||||
IF m_id IS NULL THEN
|
||||
exec_cmds := ARRAY [
|
||||
[ 'O', 'UPDATE db4ai.snapshot SET matrix_id = ' || p_id || ' WHERE schema = ''' || i_schema || ''' AND name = '''
|
||||
|| i_parent || '''' ],
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ADD _' || p_id || ' BOOLEAN NOT NULL DEFAULT TRUE' ],
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ALTER _' || p_id || ' SET DEFAULT FALSE' ],
|
||||
[ 'O', 'CREATE OR REPLACE VIEW db4ai.v' || p_id || ' WITH(security_barrier) AS SELECT ' || p_sv_proj || ', xc_node_id, ctid FROM db4ai.t'
|
||||
|| p_id || ' WHERE _' || p_id ]];
|
||||
m_id := p_id;
|
||||
END IF;
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || m_id || ' ADD _' || s_id || ' BOOLEAN NOT NULL DEFAULT FALSE' ],
|
||||
[ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id ]];
|
||||
END IF;
|
||||
|
||||
-- generate and append grant, create view and rewrite rules for new snapshot
|
||||
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
|
||||
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
|
||||
|
||||
FOREACH command_str IN ARRAY i_commands LOOP
|
||||
IF command_str IS NULL THEN
|
||||
RAISE EXCEPTION 'i_commands array contains NULL values';
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
-- apply SQL DML/DDL according to snapshot mode
|
||||
FOREACH command_str IN ARRAY (i_commands || ARRAY[NULL] ) LOOP
|
||||
command_str := btrim(command_str);
|
||||
pattern := upper(regexp_replace(left(command_str, 30), '\s+', ' ', 'g'));
|
||||
IF pattern is NULL THEN
|
||||
next_op := NULL;
|
||||
ELSIF pattern = 'ALTER' THEN -- ALTER keyword is optional
|
||||
next_op := ALTER_OP;
|
||||
ELSIF pattern = 'DELETE' THEN
|
||||
next_op := DELETE_OP;
|
||||
ELSIF pattern = 'INSERT' THEN
|
||||
next_op := INSERT_OP;
|
||||
ELSIF pattern = 'UPDATE' THEN
|
||||
next_op := UPDATE_OP;
|
||||
ELSIF left(pattern, 7) = 'DELETE ' THEN
|
||||
next_op := DELETE_OP;
|
||||
SELECT coalesce(m[1], m[2]), m [3] FROM regexp_matches(command_str,
|
||||
'^\s*DELETE\s+FROM\s*(?: snapshot |"snapshot")\s*(?:AS\s*(?: ([^\s"]+) |"((?:[^"]*"")*[^"]*)")\s*)?(.*)\s*$', 'i') m
|
||||
INTO next_clauses[AS_CLAUSE], next_clauses[FROM_CLAUSE];
|
||||
RAISE NOTICE E'XXX DELETE \n%\n%', command_str, array_to_string(next_clauses, E'\n');
|
||||
ELSIF left(pattern, 7) = 'INSERT ' THEN
|
||||
next_op := INSERT_OP;
|
||||
SELECT coalesce(m[1], m[2]), m [3] FROM regexp_matches(command_str,
|
||||
'^\s*INSERT\s+INTO\s*(?: snapshot |"snapshot")\s*(.*)\s*$', 'i') m
|
||||
INTO STRICT next_clauses[SET_CLAUSE];
|
||||
RAISE NOTICE E'XXX INSERT \n%\n%', command_str, array_to_string(next_clauses, E'\n');
|
||||
ELSIF left(pattern, 7) = 'UPDATE ' THEN
|
||||
next_op := UPDATE_OP;
|
||||
SELECT coalesce(m[1], m[2]), m [3] FROM regexp_matches(command_str,
|
||||
'^\s*UPDATE\s*(?: snapshot |"snapshot")\s*(?:AS\s*(?: ([^\s"]+) |"((?:[^"]*"")*[^"]*)")\s*)?(.*)\s*$', 'i') m
|
||||
INTO STRICT next_clauses[AS_CLAUSE], next_clauses[SET_CLAUSE];
|
||||
|
||||
DECLARE
|
||||
nested INT := 0; -- level of nesting
|
||||
quoted BOOLEAN := FALSE; -- inside quoted identifier
|
||||
cur_ch VARCHAR; -- current character in tokenizer
|
||||
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
|
||||
start INTEGER := 1;
|
||||
stmt TEXT := next_clauses[SET_CLAUSE];
|
||||
BEGIN
|
||||
|
||||
-- BEGIN splitter code for testing
|
||||
|
||||
pattern := '';
|
||||
|
||||
LOOP
|
||||
idx := idx + 1;
|
||||
cur_ch := substr(stmt, idx, 1);
|
||||
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
|
||||
|
||||
CASE cur_ch
|
||||
WHEN '"' THEN
|
||||
IF quoted AND substr(stmt, idx + 1, 1) = '"' THEN
|
||||
idx := idx + 1;
|
||||
ELSE
|
||||
quoted := NOT quoted;
|
||||
END IF;
|
||||
IF quoted THEN
|
||||
CONTINUE;
|
||||
END IF;
|
||||
WHEN '(' THEN
|
||||
nested := nested + 1;
|
||||
CONTINUE;
|
||||
WHEN ')' THEN
|
||||
nested := nested - 1;
|
||||
IF nested < 0 THEN
|
||||
RAISE EXCEPTION 'syntax error at or near '')'' in ''%'' at position ''%''', stmt, idx;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
WHEN ' ' THEN
|
||||
IF quoted OR nested > 0 THEN
|
||||
CONTINUE;
|
||||
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
|
||||
start := idx;
|
||||
CONTINUE;
|
||||
END IF;
|
||||
ELSE
|
||||
pattern := pattern || upper(cur_ch);
|
||||
CONTINUE;
|
||||
END CASE;
|
||||
|
||||
-- END splitter code for testing
|
||||
|
||||
IF pattern IN ('FROM', 'WHERE') THEN
|
||||
next_clauses[FROM_CLAUSE] := substr(next_clauses[SET_CLAUSE], start + 1);
|
||||
next_clauses[SET_CLAUSE] := left(next_clauses[SET_CLAUSE], start - 1);
|
||||
EXIT;
|
||||
END IF;
|
||||
pattern := '';
|
||||
start := idx;
|
||||
END LOOP;
|
||||
END;
|
||||
RAISE NOTICE E'XXX UPDATE \n%\n%', command_str, array_to_string(next_clauses, E'\n');
|
||||
ELSIF left(pattern, 6) = 'ALTER ' THEN
|
||||
SELECT coalesce(m[1], m[2]), m [3] FROM regexp_matches(command_str,
|
||||
'^\s*ALTER\s+TABLE\s*(?: snapshot |"snapshot")\s*(.*)\s*$', 'i') m
|
||||
INTO STRICT next_clauses[ALTER_CLAUSE];
|
||||
RAISE NOTICE E'XXX ALTER \n%\n%', command_str, array_to_string(next_clauses, E'\n');
|
||||
IF current_op IS NULL OR current_clauses[ALTER_CLAUSE] IS NULL THEN
|
||||
next_op := ALTER_OP;
|
||||
ELSE
|
||||
current_clauses[ALTER_CLAUSE] := current_clauses[ALTER_CLAUSE] || ', ' || next_clauses[ALTER_CLAUSE];
|
||||
next_clauses[ALTER_CLAUSE] := NULL;
|
||||
END IF;
|
||||
ELSIF left(pattern, 4) = 'ADD ' OR left(pattern, 5) = 'DROP ' THEN
|
||||
--for chaining, conflicting ALTER ops must be avoided by user
|
||||
IF current_op IS NULL OR current_op <> ALTER_OP THEN
|
||||
next_op := ALTER_OP; -- ALTER keyword is optional
|
||||
next_clauses[ALTER_CLAUSE] := command_str;
|
||||
ELSIF current_clauses[ALTER_CLAUSE] IS NULL THEN
|
||||
current_clauses[ALTER_CLAUSE] := command_str;
|
||||
CONTINUE; -- allow chaining of ALTER ops
|
||||
ELSE
|
||||
current_clauses[ALTER_CLAUSE] := current_clauses[ALTER_CLAUSE] || ', ' || command_str;
|
||||
CONTINUE; -- allow chaining of ALTER ops
|
||||
END IF;
|
||||
ELSIF left(pattern, 6) = 'WHERE ' THEN
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing INSERT / UPDATE / DELETE keyword before WHERE clause in i_commands at: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (INSERT_OP, UPDATE_OP, DELETE_OP) THEN
|
||||
RAISE EXCEPTION 'illegal WHERE clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
ELSIF current_clauses[WHERE_CLAUSE] IS NULL THEN
|
||||
current_clauses[WHERE_CLAUSE] := command_str;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'multiple WHERE clauses in % at: ''%''', ops_str[current_op], command_str;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
ELSIF left(pattern, 5) = 'FROM ' THEN
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing INSERT / UPDATE keyword before FROM clause in i_commands at: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (INSERT_OP, UPDATE_OP) THEN
|
||||
RAISE EXCEPTION 'illegal FROM clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
ELSIF current_clauses[FROM_CLAUSE] IS NULL THEN
|
||||
current_clauses[FROM_CLAUSE] := command_str;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'multiple FROM clauses in % at: ''%''', ops_str[current_op], command_str;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
ELSIF left(pattern, 6) = 'USING ' THEN
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing DELETE keyword before USING clause in i_commands at: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (DELETE_OP) THEN
|
||||
RAISE EXCEPTION 'illegal USING clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
ELSIF current_clauses[FROM_CLAUSE] IS NULL THEN
|
||||
current_clauses[FROM_CLAUSE] := command_str;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'multiple USING clauses in DELETE at: ''%''', command_str;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
ELSIF left(pattern, 4) = 'SET ' THEN
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing UPDATE keyword before SET clause in i_commands at: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (UPDATE_OP) THEN
|
||||
RAISE EXCEPTION 'illegal SET clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
--for chaining, conflicting assignments must be avoided by user
|
||||
ELSE -- allow chaining of SET
|
||||
current_clauses[SET_CLAUSE] := CASE WHEN current_clauses[SET_CLAUSE] IS NULL
|
||||
THEN command_str ELSE current_clauses[SET_CLAUSE] || ' ' || command_str END;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
ELSIF left(pattern, 3) = 'AS ' THEN
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing UPDATE / DELETE keyword before AS clause in i_commands at: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (UPDATE_OP, DELETE_OP) THEN
|
||||
RAISE EXCEPTION 'illegal AS clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
ELSIF current_clauses[AS_CLAUSE] IS NULL THEN
|
||||
DECLARE
|
||||
as_pos INT := 3 + strpos(upper(command_str), 'AS ');
|
||||
BEGIN
|
||||
current_clauses[AS_CLAUSE] := ltrim(substr(command_str, as_pos));
|
||||
END;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'multiple AS clauses in % at: ''%''', ops_str[current_op], command_str;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
ELSE -- generic SQL allowed only in INSERT
|
||||
IF current_op IS NULL THEN
|
||||
RAISE EXCEPTION 'missing ALTER / INSERT / UPDATE / DELETE keyword before SQL clause: ''%''', command_str;
|
||||
ELSIF current_op NOT IN (INSERT_OP) THEN
|
||||
RAISE EXCEPTION 'illegal SQL clause in % at: ''%''', ops_str[current_op], command_str;
|
||||
ELSE -- allow chaining of generic SQL for PROJ LIST, VALUES, ...
|
||||
current_clauses[SET_CLAUSE] := CASE WHEN current_clauses[SET_CLAUSE] IS NULL
|
||||
THEN command_str ELSE current_clauses[SET_CLAUSE] || ' ' || command_str END;
|
||||
END IF;
|
||||
CONTINUE;
|
||||
END IF;
|
||||
|
||||
IF current_op IS NOT NULL THEN
|
||||
IF current_clauses IS NULL AND current_op NOT IN (DELETE_OP) THEN
|
||||
RAISE EXCEPTION 'missing auxiliary clauses in %',
|
||||
CASE WHEN command_str IS NULL THEN ops_str[current_op] ELSE ops_str[current_op] || ' before: '''
|
||||
|| command_str || '''' END;
|
||||
END IF;
|
||||
IF current_clauses[AS_CLAUSE] IS NULL THEN
|
||||
current_clauses[AS_CLAUSE] := 'snapshot';
|
||||
END IF;
|
||||
ops_arr[current_op] := TRUE;
|
||||
IF current_op = ALTER_OP THEN
|
||||
command_str := NULL; -- stores DDL statement for adding / dropping columns
|
||||
|
||||
DECLARE
|
||||
dropif TEXT := NULL; -- drop if exists
|
||||
expect BOOLEAN := TRUE; -- expect keyword
|
||||
alt_op VARCHAR(4); -- alter operation: NULL or 'ADD' or 'DROP'
|
||||
|
||||
quoted BOOLEAN := FALSE; -- inside quoted identifier
|
||||
cur_ch VARCHAR; -- current character in tokenizer
|
||||
idx INTEGER := 0; -- loop counter, cannot use FOR .. iterator
|
||||
tokens TEXT := current_clauses[ALTER_CLAUSE] || ',';
|
||||
BEGIN
|
||||
|
||||
-- BEGIN tokenizer code for testing
|
||||
|
||||
pattern := '';
|
||||
|
||||
LOOP
|
||||
idx := idx + 1;
|
||||
cur_ch := substr(tokens, idx, 1);
|
||||
EXIT WHEN cur_ch IS NULL OR cur_ch = '';
|
||||
|
||||
CASE cur_ch
|
||||
WHEN '"' THEN
|
||||
IF quoted AND substr(tokens, idx + 1, 1) = '"' THEN
|
||||
pattern := pattern || '"';
|
||||
idx := idx + 1;
|
||||
ELSE
|
||||
quoted := NOT quoted;
|
||||
END IF;
|
||||
IF quoted THEN
|
||||
CONTINUE;
|
||||
END IF;
|
||||
WHEN ',' THEN
|
||||
IF quoted THEN
|
||||
pattern := pattern || cur_ch;
|
||||
CONTINUE;
|
||||
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
|
||||
pattern := ',';
|
||||
ELSE
|
||||
idx := idx - 1; -- reset on comma for next loop
|
||||
END IF;
|
||||
WHEN ' ', E'\n', E'\t' THEN
|
||||
IF quoted THEN
|
||||
pattern := pattern || cur_ch;
|
||||
CONTINUE;
|
||||
ELSIF pattern IS NULL OR length(pattern) = 0 THEN
|
||||
CONTINUE;
|
||||
END IF;
|
||||
ELSE
|
||||
pattern := pattern || CASE WHEN quoted THEN cur_ch ELSE lower(cur_ch) END;
|
||||
CONTINUE;
|
||||
END CASE;
|
||||
|
||||
-- END tokenizer code for testing
|
||||
|
||||
IF alt_op = 'DROP' AND upper(dropif) = 'IF' THEN
|
||||
IF pattern = ',' THEN
|
||||
pattern := dropif; -- interpret 'if' as column name (not a keyword)
|
||||
idx := idx - 1; -- reset on comma for next loop
|
||||
ELSIF upper(pattern) <> 'EXISTS' THEN
|
||||
RAISE EXCEPTION 'expected EXISTS keyword in % operation after ''%'' in: ''%''',
|
||||
alt_op, dropif, current_clauses[ALTER_CLAUSE];
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
IF expect THEN
|
||||
IF upper(pattern) IN ('ADD', 'DROP') THEN
|
||||
IF alt_op IS NULL THEN
|
||||
alt_op := upper(pattern);
|
||||
expect := FALSE;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'unable to extract column name in % operation: ''%''',
|
||||
alt_op, current_clauses[ALTER_CLAUSE];
|
||||
END IF;
|
||||
ELSE
|
||||
RAISE EXCEPTION 'expected ADD or DROP keyword before ''%'' in: ''%''',
|
||||
pattern, current_clauses[ALTER_CLAUSE]
|
||||
USING HINT = 'currently only ADD and DROP supported';
|
||||
END IF;
|
||||
ELSIF pattern = ',' THEN
|
||||
expect := TRUE; -- allow chaining of ALTER ops
|
||||
ELSIF alt_op IS NULL THEN
|
||||
-- accept all possibly legal text following column name
|
||||
-- leave exact syntax check to SQL compiler
|
||||
IF command_str IS NOT NULL THEN
|
||||
command_str := command_str || ' ' || pattern;
|
||||
END IF;
|
||||
ELSIF upper(pattern) = 'COLUMN' THEN
|
||||
-- skip keyword COLUMN between ADD/DROP and column name
|
||||
ELSIF alt_op = 'DROP' AND upper(pattern) = 'IF' AND dropif IS NULL THEN
|
||||
dropif := pattern; -- 'IF' is not a keyword
|
||||
ELSIF alt_op = 'DROP' AND upper(pattern) = 'EXISTS' AND upper(dropif) = 'IF' THEN
|
||||
dropif := pattern; -- 'EXISTS' is not a keyword
|
||||
ELSIF alt_op IN ('ADD', 'DROP') THEN
|
||||
|
||||
-- attempt to map the pattern
|
||||
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
|
||||
IF pattern = mapping[idx] THEN
|
||||
IF alt_op = 'ADD' THEN
|
||||
-- check if pattern was mapped to an existing column
|
||||
RAISE EXCEPTION 'column "%" already exists in current snapshot', pattern;
|
||||
ELSIF alt_op = 'DROP' THEN
|
||||
-- DROP a private column (MSS and CSS)
|
||||
IF mapping[idx-2] IS NOT NULL THEN
|
||||
command_str := CASE WHEN command_str IS NULL
|
||||
THEN 'ALTER TABLE db4ai.t' || coalesce(m_id, s_id)
|
||||
ELSE command_str || ',' END || ' DROP ' || mapping[idx-2];
|
||||
END IF;
|
||||
mapping := mapping[1:(idx-3)] || mapping[idx+1:(array_length(mapping, 1))];
|
||||
newmap := TRUE;
|
||||
alt_op := NULL;
|
||||
EXIT;
|
||||
END IF;
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
-- apply the mapping
|
||||
IF alt_op = 'ADD' THEN
|
||||
-- ADD a private column (MSS and CSS)
|
||||
command_str := CASE WHEN command_str IS NULL
|
||||
THEN 'ALTER TABLE db4ai.t' || coalesce(m_id, s_id)
|
||||
ELSE command_str || ',' END || ' ADD f' || c_id;
|
||||
mapping := mapping || ARRAY [ 'f' || c_id, NULL, pattern ]::NAME[];
|
||||
newmap := TRUE;
|
||||
c_id := c_id + 1;
|
||||
ELSIF alt_op = 'DROP' THEN
|
||||
-- check whether pattern needs mapping to an existing column
|
||||
IF dropif IS NULL OR upper(dropif) <> 'EXISTS' THEN
|
||||
RAISE EXCEPTION 'unable to map field "%" to backing table in % operation: ''%''',
|
||||
pattern, alt_op, current_clauses[ALTER_CLAUSE];
|
||||
END IF;
|
||||
END IF;
|
||||
dropif := NULL;
|
||||
alt_op := NULL;
|
||||
ELSE
|
||||
-- checked before, this should never happen
|
||||
RAISE EXCEPTION 'unexpected ALTER clause: %', alt_op;
|
||||
END IF;
|
||||
|
||||
pattern := '';
|
||||
END LOOP;
|
||||
|
||||
IF quoted THEN
|
||||
RAISE EXCEPTION 'unterminated quoted identifier ''"%'' at or near: ''%''',
|
||||
substr(pattern, 1, char_length(pattern)-1), current_clauses[ALTER_CLAUSE];
|
||||
END IF;
|
||||
|
||||
-- CREATE OR REPLACE: cannot drop columns from view - MUST use DROP / CREATE
|
||||
-- clear view dependencies for backing table columns
|
||||
exec_cmds := exec_cmds || ARRAY [ 'O', 'DROP VIEW IF EXISTS db4ai.v' || s_id ];
|
||||
|
||||
-- append the DDL statement for the backing table (if any)
|
||||
IF command_str IS NOT NULL THEN
|
||||
exec_cmds := exec_cmds || ARRAY [ 'O', command_str ];
|
||||
END IF;
|
||||
|
||||
IF newmap THEN
|
||||
-- generate and append grant, create view and rewrite rules for new snapshot
|
||||
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
|
||||
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
|
||||
newmap := FALSE;
|
||||
END IF;
|
||||
END;
|
||||
ELSIF current_op = INSERT_OP THEN
|
||||
IF current_clauses[SET_CLAUSE] IS NULL THEN
|
||||
RAISE EXCEPTION 'missing SELECT or VALUES clause in INSERT operation';
|
||||
END IF;
|
||||
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
'U', 'INSERT INTO db4ai.v' || s_id
|
||||
|| ' ' || current_clauses[SET_CLAUSE] -- generic SQL
|
||||
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END
|
||||
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
|
||||
ELSIF current_op = DELETE_OP THEN
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
'U', 'DELETE FROM db4ai.v' || s_id || ' AS ' || current_clauses[AS_CLAUSE]
|
||||
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END -- USING
|
||||
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
|
||||
ELSIF current_op = UPDATE_OP THEN
|
||||
command_str := NULL; -- stores DDL statement for adding shadow columns
|
||||
|
||||
IF current_clauses[SET_CLAUSE] IS NULL THEN
|
||||
RAISE EXCEPTION 'missing SET clause in UPDATE operation';
|
||||
END IF;
|
||||
|
||||
-- extract updated fields and check their mapping
|
||||
FOR pattern IN
|
||||
SELECT coalesce(m[1], replace(m[2],'""','"'))
|
||||
FROM regexp_matches(current_clauses[SET_CLAUSE],
|
||||
'([^\s"]+)\s*=|"((?:[^"]*"")*[^"]*)"\s*=','g') m
|
||||
LOOP
|
||||
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
|
||||
IF pattern = mapping[idx] THEN
|
||||
-- ADD a private column (only CSS)
|
||||
IF mapping[idx-2] IS NULL THEN
|
||||
command_str := CASE WHEN command_str IS NULL
|
||||
THEN 'ALTER TABLE db4ai.t' || m_id
|
||||
ELSE command_str || ',' END
|
||||
|| ' ADD f' || c_id || ' '
|
||||
|| format_type(atttypid, atttypmod) FROM pg_catalog.pg_attribute
|
||||
WHERE attrelid = ('db4ai.t' || m_id)::regclass AND attname = mapping[idx-1];
|
||||
mapping[idx-2] := 'f' || c_id;
|
||||
newmap := TRUE;
|
||||
c_id := c_id + 1;
|
||||
END IF;
|
||||
pattern := NULL;
|
||||
EXIT;
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
-- check if pattern was mapped
|
||||
IF pattern IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'unable to map field "%" to backing table in UPDATE operation: %',
|
||||
pattern, current_clauses[SET_CLAUSE];
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
-- append the DDL statement for the backing table for adding shadow columns (if any)
|
||||
IF command_str IS NOT NULL THEN
|
||||
exec_cmds := exec_cmds || ARRAY [ 'O', command_str ];
|
||||
END IF;
|
||||
|
||||
IF newmap THEN
|
||||
-- generate and append grant, create view and rewrite rules for new snapshot
|
||||
exec_cmds := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
|
||||
CURRENT_USER, NULL, exec_cmds, mapping)).i_exec_cmds;
|
||||
newmap := FALSE;
|
||||
END IF;
|
||||
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
'U', 'UPDATE db4ai.v' || s_id || ' AS ' || current_clauses[AS_CLAUSE]
|
||||
|| ' ' || current_clauses[SET_CLAUSE]
|
||||
|| CASE WHEN current_clauses[FROM_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[FROM_CLAUSE] END
|
||||
|| CASE WHEN current_clauses[WHERE_CLAUSE] IS NULL THEN '' ELSE ' ' || current_clauses[WHERE_CLAUSE] END ];
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
current_op := next_op;
|
||||
next_op := NULL;
|
||||
-- restore ALTER clause for ADD / DROP without 'ALTER' keyword, else reset to NULL
|
||||
current_clauses := next_clauses;
|
||||
next_clauses := NULL;
|
||||
END LOOP;
|
||||
|
||||
-- compute final version string
|
||||
IF i_vers IS NULL OR i_vers = '' THEN
|
||||
BEGIN
|
||||
vers_arr := regexp_split_to_array(p_name_vers[2], CASE s_vers_sep WHEN '.' THEN '\.' ELSE s_vers_sep END);
|
||||
|
||||
IF array_length(vers_arr, 1) <> 3 OR array_length(vers_arr, 2) IS NOT NULL OR
|
||||
vers_arr[1] ~ '[^0-9]' OR vers_arr[2] ~ '[^0-9]' OR vers_arr[3] ~ '[^0-9]' THEN
|
||||
RAISE EXCEPTION 'illegal version format';
|
||||
END IF;
|
||||
IF ops_arr[ALTER_OP] THEN
|
||||
vers_arr[1] := vers_arr[1] + 1;
|
||||
vers_arr[2] := 0;
|
||||
vers_arr[3] := 0;
|
||||
ELSIF ops_arr[INSERT_OP] OR ops_arr[DELETE_OP] THEN
|
||||
vers_arr[2] := vers_arr[2] + 1;
|
||||
vers_arr[3] := 0;
|
||||
ELSE
|
||||
vers_arr[3] := vers_arr[3] + 1;
|
||||
END IF;
|
||||
i_vers := s_vers_del || array_to_string(vers_arr, s_vers_sep);
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE EXCEPTION 'parent has nonstandard version %. i_vers cannot be null or empty', p_name_vers[2]
|
||||
USING HINT = 'provide custom version using i_vers parameter for new snapshot';
|
||||
END;
|
||||
ELSE
|
||||
i_vers := replace(i_vers, chr(2), s_vers_sep);
|
||||
IF LEFT(i_vers, 1) <> s_vers_del THEN
|
||||
i_vers := s_vers_del || i_vers;
|
||||
ELSIF char_length(i_vers) < 2 THEN
|
||||
RAISE EXCEPTION 'illegal i_vers: ''%''', s_vers_del;
|
||||
END IF;
|
||||
IF strpos(substr(i_vers, 2), s_vers_del) > 0 THEN
|
||||
RAISE EXCEPTION 'i_vers may contain only one single, leading ''%'' character', s_vers_del
|
||||
USING HINT = 'specify snapshot version as [' || s_vers_del || ']x' || s_vers_sep || 'y' || s_vers_sep || 'z or ['
|
||||
|| s_vers_del || ']label with optional, leading ''' || s_vers_del || '''';
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
IF char_length(p_name_vers[1] || i_vers) > 63 THEN
|
||||
RAISE EXCEPTION 'snapshot name too long: ''%''', p_name_vers[1] || i_vers;
|
||||
ELSE
|
||||
s_name := p_name_vers[1] || i_vers;
|
||||
END IF;
|
||||
|
||||
-- the final name of the snapshot
|
||||
qual_name := quote_ident(i_schema) || '.' || quote_ident(s_name);
|
||||
|
||||
-- check for duplicate snapshot
|
||||
IF 0 < (SELECT COUNT(0) FROM db4ai.snapshot WHERE schema = i_schema AND name = s_name) THEN
|
||||
RAISE EXCEPTION 'snapshot % already exists' , qual_name;
|
||||
END IF;
|
||||
|
||||
IF s_mode = 'MSS' THEN
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
'O', 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || qual_name || '''' ];
|
||||
END IF;
|
||||
|
||||
-- Execute the queries
|
||||
RAISE NOTICE E'accumulated commands:\n%', array_to_string(exec_cmds, E'\n');
|
||||
DECLARE
|
||||
idx INTEGER := 1; -- loop counter, cannot use FOR .. iterator
|
||||
BEGIN
|
||||
LOOP EXIT WHEN idx = 1 + array_length(exec_cmds, 1);
|
||||
WHILE exec_cmds[idx][1] = 'U' LOOP
|
||||
-- RAISE NOTICE 'user executing: %', exec_cmds[idx][2];
|
||||
DECLARE
|
||||
e_message TEXT; -- exception message
|
||||
BEGIN
|
||||
EXECUTE exec_cmds[idx][2];
|
||||
idx := idx + 1;
|
||||
EXCEPTION WHEN undefined_table THEN
|
||||
GET STACKED DIAGNOSTICS e_message = MESSAGE_TEXT;
|
||||
|
||||
-- during function invocation, search path is redirected to {pg_temp, pg_catalog, function_schema} and becomes immutable
|
||||
RAISE INFO 'could not resolve relation % using system-defined "search_path" setting during function invocation: ''%''',
|
||||
substr(e_message, 10, 1 + strpos(substr(e_message,11), '" does not exist')),
|
||||
array_to_string(current_schemas(TRUE),', ')
|
||||
USING HINT = 'snapshots require schema-qualified table references, e.g. schema_name.table_name';
|
||||
RAISE;
|
||||
END;
|
||||
END LOOP;
|
||||
|
||||
IF idx < array_length(exec_cmds, 1) AND (exec_cmds[idx][1] IS NULL OR exec_cmds[idx][1] <> 'O') THEN -- this should never happen
|
||||
RAISE EXCEPTION 'prepare snapshot internal error1: % %', idx, exec_cmds[idx];
|
||||
END IF;
|
||||
|
||||
-- execute owner statements (if any) and epilogue
|
||||
idx := (db4ai.prepare_snapshot_internal(s_id, p_id, m_id, r_id, i_schema, s_name, i_commands, i_comment,
|
||||
CURRENT_USER, idx, exec_cmds)).i_idx;
|
||||
END LOOP;
|
||||
END;
|
||||
|
||||
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
|
||||
s_uv_proj := s_uv_proj || quote_ident(mapping[idx]) || ',';
|
||||
END LOOP;
|
||||
-- create custom view, owned by current user
|
||||
EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT '|| rtrim(s_uv_proj, ',') || ' FROM db4ai.v' || s_id;
|
||||
EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
|
||||
|| CASE WHEN length(i_comment) > 0 THEN ' comment is "' || i_comment || '"' ELSE '' END || '''';
|
||||
EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;
|
||||
|
||||
-- return final snapshot name
|
||||
res := ROW(i_schema, s_name);
|
||||
return res;
|
||||
|
||||
END;
|
||||
$$;
|
||||
COMMENT ON FUNCTION db4ai.prepare_snapshot() IS 'Prepare snapshot from existing for data curation';
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* publish.sql
|
||||
* Publish DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/publish.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
-- Internal worker shared by db4ai.publish_snapshot and db4ai.archive_snapshot.
-- Flips the published/archived flags on a snapshot's catalog row in db4ai.snapshot.
-- SECURITY DEFINER: runs with owner privileges, so it refuses direct invocation by
-- inspecting the PL/pgSQL call stack (SECURITY_STACK_CHECK block below).
-- NOTE(review): the stack check matches 'line 11 at assignment' in the caller, so the
-- body line layout of archive_snapshot/publish_snapshot is load-bearing — keep in sync.
CREATE OR REPLACE FUNCTION db4ai.manage_snapshot_internal(
    IN i_schema NAME,               -- snapshot namespace
    IN i_name NAME,                 -- snapshot name
    IN publish BOOLEAN              -- TRUE: publish the snapshot, FALSE: archive it
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp SET client_min_messages TO ERROR
AS $$
DECLARE
    s_mode VARCHAR(3);              -- current snapshot mode
    s_vers_del CHAR;                -- snapshot version delimiter, default '@'
    s_vers_sep CHAR;                -- snapshot version separator, default '.'
    s_name_vers TEXT[];             -- split snapshot id into name and version
    e_stack_act TEXT;               -- current stack for validation
    res db4ai.snapshot_name;        -- composite result
BEGIN

    BEGIN
        RAISE EXCEPTION 'SECURITY_STACK_CHECK';    -- dummy error, raised only to capture the call stack
    EXCEPTION WHEN OTHERS THEN
        GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;

        IF CURRENT_SCHEMA = 'db4ai' THEN
            -- stack frames omit the schema qualifier when it is the current schema; normalize for matching
            e_stack_act := replace(e_stack_act, ' archive_snapshot(', ' db4ai.archive_snapshot(');
            e_stack_act := replace(e_stack_act, ' publish_snapshot(', ' db4ai.publish_snapshot(');
        END IF;

        -- reject any invocation that did not come from the two public wrapper functions
        IF e_stack_act NOT SIMILAR TO '%PL/pgSQL function db4ai.(archive|publish)_snapshot\(name,name\) line 11 at assignment%'
        THEN
            RAISE EXCEPTION 'direct call to db4ai.manage_snapshot_internal(name,name,boolean) is not allowed'
            USING HINT = 'call public interface db4ai.(publish|archive)_snapshot instead';
        END IF;
    END;

    -- obtain active message level (best-effort; silently ignored if the GUC is undefined)
    BEGIN
        EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
        RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
    EXCEPTION WHEN OTHERS THEN
    END;

    -- obtain relevant configuration parameters
    BEGIN
        s_mode := upper(current_setting('db4ai_snapshot_mode'));
    EXCEPTION WHEN OTHERS THEN
        s_mode := 'MSS';            -- default: materialized snapshot storage
    END;

    IF s_mode NOT IN ('CSS', 'MSS') THEN
        RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
    END IF;

    -- obtain relevant configuration parameters
    BEGIN
        s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
    EXCEPTION WHEN OTHERS THEN
        s_vers_del := '@';          -- default version delimiter
    END;
    BEGIN
        s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
    EXCEPTION WHEN OTHERS THEN
        s_vers_sep := '.';          -- default version separator
    END;

    -- check all input parameters
    IF i_name IS NULL OR i_name = '' THEN
        RAISE EXCEPTION 'i_name cannot be NULL or empty';
    ELSE
        i_name := replace(i_name, chr(1), s_vers_del);    -- chr(1)/chr(2) appear to be escape placeholders for delimiter/separator — confirm with callers
        i_name := replace(i_name, chr(2), s_vers_sep);
        s_name_vers := regexp_split_to_array(i_name, s_vers_del);
        IF array_length(s_name_vers, 1) <> 2 OR array_length(s_name_vers, 2) IS NOT NULL THEN
            RAISE EXCEPTION 'i_name must contain exactly one ''%'' character', s_vers_del
            USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
        END IF;
    END IF;

    -- published and archived are kept mutually exclusive
    UPDATE db4ai.snapshot SET published = publish, archived = NOT publish WHERE schema = i_schema AND name = i_name;
    IF SQL%ROWCOUNT = 0 THEN        -- openGauss/Oracle-style row count; plain PostgreSQL would need GET DIAGNOSTICS
        RAISE EXCEPTION 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
    END IF;

    -- return the (unchanged) qualified snapshot name
    res := ROW(i_schema, i_name);
    return res;

END;
$$;
|
||||
|
||||
-- Public interface: mark a snapshot as archived (excluded from model training).
-- Thin SECURITY INVOKER wrapper around db4ai.manage_snapshot_internal.
-- WARNING(review): manage_snapshot_internal validates its caller via the PL/pgSQL
-- stack and requires the assignment below to sit at body line 11 exactly;
-- do not insert or delete lines inside this function body.
CREATE OR REPLACE FUNCTION db4ai.archive_snapshot(
    IN i_schema NAME,    -- snapshot namespace, default is CURRENT_USER
    IN i_name NAME       -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
    res db4ai.snapshot_name;    -- composite result
BEGIN

    IF i_schema IS NULL OR i_schema = '' THEN    -- default schema: CURRENT_USER if such a namespace exists, else public
        i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
    END IF;

    -- return archived snapshot name
    res := db4ai.manage_snapshot_internal(i_schema, i_name, FALSE);
    return res;

END;
$$;
-- NOTE(review): empty parameter list in COMMENT ON resolves by name in openGauss;
-- plain PostgreSQL would require the full signature — confirm target dialect.
COMMENT ON FUNCTION db4ai.archive_snapshot() IS 'Archive snapshot for preventing usage in model training';
|
||||
|
||||
-- Public interface: mark a snapshot as published (usable in model training).
-- Thin SECURITY INVOKER wrapper around db4ai.manage_snapshot_internal.
-- WARNING(review): manage_snapshot_internal validates its caller via the PL/pgSQL
-- stack and requires the assignment below to sit at body line 11 exactly;
-- do not insert or delete lines inside this function body.
CREATE OR REPLACE FUNCTION db4ai.publish_snapshot(
    IN i_schema NAME,    -- snapshot namespace, default is CURRENT_USER or PUBLIC
    IN i_name NAME       -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
    res db4ai.snapshot_name;    -- composite result
BEGIN

    IF i_schema IS NULL OR i_schema = '' THEN    -- default schema: CURRENT_USER if such a namespace exists, else public
        i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
    END IF;

    -- return published snapshot name
    res := db4ai.manage_snapshot_internal(i_schema, i_name, TRUE);
    return res;

END;
$$;
-- NOTE(review): empty parameter list in COMMENT ON resolves by name in openGauss;
-- plain PostgreSQL would require the full signature — confirm target dialect.
COMMENT ON FUNCTION db4ai.publish_snapshot() IS 'Publish snapshot for allowing usage in model training';
|
|
@ -0,0 +1,199 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* purge.sql
|
||||
* Purge DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/purge.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
-- Internal worker for db4ai.purge_snapshot: removes a snapshot's catalog row and
-- reclaims its backing storage (system view plus MSS table or CSS matrix rows).
-- SECURITY DEFINER: refuses direct invocation by validating the PL/pgSQL call stack.
CREATE OR REPLACE FUNCTION db4ai.purge_snapshot_internal(
    IN i_schema NAME,    -- snapshot namespace
    IN i_name NAME       -- snapshot name
)
RETURNS VOID LANGUAGE plpgsql SECURITY DEFINER SET search_path = pg_catalog, pg_temp
AS $$
DECLARE
    s_id BIGINT;             -- snapshot id
    p_id BIGINT;             -- parent id
    m_id BIGINT;             -- matrix id (NULL for MSS snapshots)
    o_id BIGINT[];           -- other snapshot ids in same backing table
    pushed_cmds TEXT[];      -- commands to be pushed to descendants
    pushed_comment TEXT;     -- comments to be pushed to descendants
    drop_cols NAME[];        -- orphaned columns
    e_stack_act TEXT;        -- current stack for validation
    affected BIGINT;         -- number of affected rows
BEGIN

    BEGIN
        RAISE EXCEPTION 'SECURITY_STACK_CHECK';    -- dummy error, raised only to capture the call stack
    EXCEPTION WHEN OTHERS THEN
        GET STACKED DIAGNOSTICS e_stack_act = PG_EXCEPTION_CONTEXT;

        IF CURRENT_SCHEMA = 'db4ai' THEN
            -- stack frames omit the schema qualifier when it is the current schema; normalize for matching
            e_stack_act := replace(e_stack_act, 'ion pur', 'ion db4ai.pur');
        END IF;

        -- NOTE(review): this pattern pins the caller's PERFORM to body line 62 of
        -- db4ai.purge_snapshot — that function's line layout is load-bearing
        IF e_stack_act NOT LIKE 'referenced column: purge_snapshot_internal
SQL statement "SELECT db4ai.purge_snapshot_internal(i_schema, i_name)"
PL/pgSQL function db4ai.purge_snapshot(name,name) line 62 at PERFORM%'
        THEN
            RAISE EXCEPTION 'direct call to db4ai.purge_snapshot_internal(name,name) is not allowed'
            USING HINT = 'call public interface db4ai.purge_snapshot instead';
        END IF;
    END;

    -- check if snapshot exists, and fetch its lineage metadata in one pass
    BEGIN
        SELECT commands, comment, id, parent_id, matrix_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name
        INTO STRICT pushed_cmds, pushed_comment, s_id, p_id, m_id;
    EXCEPTION WHEN NO_DATA_FOUND THEN
        RAISE EXCEPTION 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
    END;

    -- update descendants, if any: re-parent them to this snapshot's parent and push
    -- this snapshot's commands/comment down so their provenance stays complete
    UPDATE db4ai.snapshot SET
        parent_id = p_id,
        commands = pushed_cmds || commands,
        comment = CASE WHEN pushed_comment IS NULL THEN comment
                       WHEN comment IS NULL THEN pushed_comment
                       ELSE pushed_comment || ' | ' || comment END
    WHERE parent_id = s_id;
    IF p_id IS NULL AND SQL%ROWCOUNT > 0 THEN
        -- a root snapshot with descendants cannot be purged: the descendants' data lives in its backing table
        RAISE EXCEPTION 'cannot purge root snapshot ''%.%'' having dependent snapshots', quote_ident(i_schema), quote_ident(i_name)
        USING HINT = 'purge all dependent snapshots first';
    END IF;

    IF m_id IS NULL THEN
        -- MSS snapshot owns its backing table exclusively: drop view and table
        EXECUTE 'DROP VIEW db4ai.v' || s_id;
        EXECUTE 'DROP TABLE db4ai.t' || s_id;
        RAISE NOTICE 'PURGE_SNAPSHOT: MSS backing table dropped';
    ELSE
        -- CSS snapshot shares matrix table db4ai.t<m_id> with sibling snapshots
        SELECT array_agg(id) FROM db4ai.snapshot WHERE matrix_id = m_id AND id <> s_id INTO STRICT o_id;

        IF o_id IS NULL OR array_length(o_id, 1) = 0 THEN
            -- last snapshot in the matrix: the shared backing table can go too
            EXECUTE 'DROP VIEW db4ai.v' || s_id;
            EXECUTE 'DROP TABLE db4ai.t' || m_id;
            RAISE NOTICE 'PURGE_SNAPSHOT: CSS backing table dropped';
        ELSE
            -- delete rows visible only to this snapshot (membership flags are boolean columns named _<id>)
            EXECUTE 'DELETE FROM db4ai.t' || m_id || ' WHERE _' || s_id || ' AND NOT (_' || array_to_string(o_id, ' OR _') || ')';
            GET DIAGNOSTICS affected = ROW_COUNT;

            -- collect backing columns referenced exclusively by this snapshot's view
            SELECT array_agg(quote_ident(column_name))
            FROM ( SELECT column_name
                   FROM information_schema.columns
                   WHERE table_schema = 'db4ai' AND table_name = ANY ( ('{v' || array_to_string(s_id || o_id, ',v') || '}')::NAME[] )
                   GROUP BY column_name
                   HAVING SUM(CASE table_name WHEN 'v' || s_id THEN 0 ELSE 1 END) = 0 )
            INTO STRICT drop_cols;

            EXECUTE 'DROP VIEW db4ai.v' || s_id;

            -- NOTE(review): 'TRUE OR' makes the ELSE branch below unreachable, so orphaned
            -- columns are never actually dropped — looks like a deliberately disabled
            -- code path; confirm intent before re-enabling
            IF TRUE OR drop_cols IS NULL THEN
                EXECUTE 'ALTER TABLE db4ai.t' || m_id || ' DROP _' || s_id;
                RAISE NOTICE 'PURGE_SNAPSHOT: orphaned rows dropped: %, orphaned columns dropped: none', affected;
            ELSE
                EXECUTE 'ALTER TABLE db4ai.t' || m_id || ' DROP _' || s_id || ', DROP ' || array_to_string(drop_cols, ', DROP ');
                RAISE NOTICE 'PURGE_SNAPSHOT: orphaned rows dropped: %, orphaned columns dropped: %', affected, drop_cols;
            END IF;
        END IF;
    END IF;

    DELETE FROM db4ai.snapshot WHERE schema = i_schema AND name = i_name;
    IF SQL%ROWCOUNT = 0 THEN
        -- checked before, this should never happen
        RAISE INFO 'snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_name);
    END IF;
END;
$$;
|
||||
|
||||
-- Public interface: purge a snapshot and reclaim its storage.
-- Drops the user-facing view, then delegates to db4ai.purge_snapshot_internal.
-- WARNING(review): purge_snapshot_internal validates its caller via the PL/pgSQL
-- stack and requires the PERFORM below to sit at body line 62 exactly;
-- do not insert or delete lines inside this function body.
CREATE OR REPLACE FUNCTION db4ai.purge_snapshot(
    IN i_schema NAME,    -- snapshot namespace, default is CURRENT_USER or PUBLIC
    IN i_name NAME       -- snapshot name
)
RETURNS db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
AS $$
DECLARE
    s_mode VARCHAR(3);          -- current snapshot mode
    s_vers_del CHAR;            -- snapshot version delimiter, default '@'
    s_vers_sep CHAR;            -- snapshot version separator, default '.'
    s_name_vers TEXT[];         -- split full name into name and version
    res db4ai.snapshot_name;    -- composite result
BEGIN

    -- obtain active message level (best-effort; silently ignored if the GUC is undefined)
    BEGIN
        EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
        RAISE INFO 'effective client_min_messages is ''%''', upper(current_setting('db4ai.message_level'));
    EXCEPTION WHEN OTHERS THEN
    END;

    -- obtain active snapshot mode
    BEGIN
        s_mode := upper(current_setting('db4ai_snapshot_mode'));
    EXCEPTION WHEN OTHERS THEN
        s_mode := 'MSS';        -- default: materialized snapshot storage
    END;

    IF s_mode NOT IN ('CSS', 'MSS') THEN
        RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
    END IF;

    -- obtain relevant configuration parameters
    BEGIN
        s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
    EXCEPTION WHEN OTHERS THEN
        s_vers_del := '@';      -- default version delimiter
    END;
    BEGIN
        s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
    EXCEPTION WHEN OTHERS THEN
        s_vers_sep := '.';      -- default version separator
    END;

    -- check all input parameters
    IF i_schema IS NULL OR i_schema = '' THEN    -- default schema: CURRENT_USER if such a namespace exists, else public
        i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
    END IF;

    IF i_name IS NULL OR i_name = '' THEN
        RAISE EXCEPTION 'i_name cannot be NULL or empty';
    ELSE
        i_name := replace(i_name, chr(1), s_vers_del);    -- chr(1)/chr(2) appear to be escape placeholders for delimiter/separator — confirm
        i_name := replace(i_name, chr(2), s_vers_sep);
        s_name_vers := regexp_split_to_array(i_name, s_vers_del);
        IF array_length(s_name_vers, 1) <> 2 OR array_length(s_name_vers, 2) IS NOT NULL THEN
            RAISE EXCEPTION 'i_name must contain exactly one ''%'' character', s_vers_del
            USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
        END IF;
    END IF;

    BEGIN
        EXECUTE 'DROP VIEW ' || quote_ident(i_schema) || '.' || quote_ident(i_name);    -- drop the user-facing view; best-effort
    EXCEPTION WHEN OTHERS THEN
    END;

    PERFORM db4ai.purge_snapshot_internal(i_schema, i_name);

    -- return purged snapshot name
    res := ROW(i_schema, i_name);
    return res;

END;
$$;
COMMENT ON FUNCTION db4ai.purge_snapshot() IS 'Purge a snapshot and reclaim occupied storage';
|
|
@ -0,0 +1,269 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* sample.sql
|
||||
* Sample DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/sample.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
CREATE OR REPLACE FUNCTION db4ai.sample_snapshot(
|
||||
IN i_schema NAME, -- snapshot namespace, default is CURRENT_USER or PUBLIC
|
||||
IN i_parent NAME, -- parent snapshot name
|
||||
IN i_sample_infixes NAME[], -- sample snapshot name infixes
|
||||
IN i_sample_ratios NUMBER[], -- size of each sample, as a ratio of the parent set
|
||||
IN i_stratify NAME[] DEFAULT NULL, -- stratification fields
|
||||
IN i_sample_comments TEXT[] DEFAULT NULL -- sample snapshot descriptions
|
||||
)
|
||||
RETURNS SETOF db4ai.snapshot_name LANGUAGE plpgsql SECURITY INVOKER SET client_min_messages TO ERROR
|
||||
AS $$
|
||||
DECLARE
|
||||
s_id BIGINT; -- snapshot id
|
||||
p_id BIGINT; -- parent id
|
||||
m_id BIGINT; -- matrix id
|
||||
r_id BIGINT; -- root id
|
||||
s_mode VARCHAR(3); -- current snapshot mode
|
||||
s_vers_del CHAR; -- snapshot version delimiter, default '@'
|
||||
s_vers_sep CHAR; -- snapshot version separator, default '.'
|
||||
s_sv_proj TEXT; -- snapshot system view projection list
|
||||
s_bt_proj TEXT; -- snapshot backing table projection list
|
||||
s_bt_dist TEXT; -- DISTRIBUTE BY clause for creating backing table
|
||||
s_uv_proj TEXT; -- snapshot user view projection list
|
||||
p_sv_proj TEXT; -- parent snapshot system view projection list
|
||||
p_name_vers TEXT[]; -- split full parent name into name and version
|
||||
stratify_count BIGINT[]; -- count per stratification class
|
||||
exec_cmds TEXT[]; -- commands for execution
|
||||
qual_name TEXT; -- qualified snapshot name
|
||||
mapping NAME[]; -- mapping user column names to backing column names
|
||||
s_name db4ai.snapshot_name; -- snapshot sample name
|
||||
BEGIN
|
||||
|
||||
-- obtain active message level
|
||||
BEGIN
|
||||
EXECUTE 'SET LOCAL client_min_messages TO ' || current_setting('db4ai.message_level');
|
||||
RAISE INFO 'effective client_min_messages is %', upper(current_setting('db4ai.message_level'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
END;
|
||||
|
||||
-- obtain active snapshot mode
|
||||
BEGIN
|
||||
s_mode := upper(current_setting('db4ai_snapshot_mode'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_mode := 'MSS';
|
||||
END;
|
||||
|
||||
IF s_mode NOT IN ('CSS', 'MSS') THEN
|
||||
RAISE EXCEPTION 'invalid snapshot mode: ''%''', s_mode;
|
||||
END IF;
|
||||
|
||||
-- obtain relevant configuration parameters
|
||||
BEGIN
|
||||
s_vers_del := current_setting('db4ai_snapshot_version_delimiter');
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_vers_del := '@';
|
||||
END;
|
||||
BEGIN
|
||||
s_vers_sep := upper(current_setting('db4ai_snapshot_version_separator'));
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
s_vers_sep := '.';
|
||||
END;
|
||||
|
||||
-- check all input parameters
|
||||
IF i_schema IS NULL OR i_schema = '' THEN
|
||||
i_schema := CASE WHEN (SELECT 0=COUNT(*) FROM pg_catalog.pg_namespace WHERE nspname = CURRENT_USER) THEN 'public' ELSE CURRENT_USER END;
|
||||
END IF;
|
||||
|
||||
IF i_parent IS NULL OR i_parent = '' THEN
|
||||
RAISE EXCEPTION 'i_parent cannot be NULL or empty';
|
||||
ELSE
|
||||
i_parent := replace(i_parent, chr(1), s_vers_del);
|
||||
i_parent := replace(i_parent, chr(2), s_vers_sep);
|
||||
p_name_vers := regexp_split_to_array(i_parent, s_vers_del);
|
||||
IF array_length(p_name_vers, 1) <> 2 OR array_length(p_name_vers, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_parent must contain exactly one ''%'' character', s_vers_del
|
||||
USING HINT = 'reference a snapshot using the format: snapshot_name' || s_vers_del || 'version';
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
-- check if parent exists
|
||||
BEGIN
|
||||
SELECT id, matrix_id, root_id FROM db4ai.snapshot WHERE schema = i_schema AND name = i_parent INTO STRICT p_id, m_id, r_id;
|
||||
EXCEPTION WHEN NO_DATA_FOUND THEN
|
||||
RAISE EXCEPTION 'parent snapshot %.% does not exist' , quote_ident(i_schema), quote_ident(i_parent);
|
||||
END;
|
||||
|
||||
IF i_sample_infixes IS NULL OR array_length(i_sample_infixes, 1) IS NULL OR array_length(i_sample_infixes, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_sample_infixes array malformed'
|
||||
USING HINT = 'pass sample infixes as NAME[] literal, e.g. ''{_train, _test}''';
|
||||
END IF;
|
||||
|
||||
IF i_sample_ratios IS NULL OR array_length(i_sample_ratios, 1) IS NULL OR array_length(i_sample_ratios, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_sample_ratios array malformed'
|
||||
USING HINT = 'pass sample percentages as NUMBER[] literal, e.g. ''{.8, .2}''';
|
||||
END IF;
|
||||
|
||||
IF array_length(i_sample_infixes, 1) <> array_length(i_sample_ratios, 1) THEN
|
||||
RAISE EXCEPTION 'i_sample_infixes and i_sample_ratios array length mismatch';
|
||||
END IF;
|
||||
|
||||
IF i_stratify IS NOT NULL THEN
|
||||
IF array_length(i_stratify, 1) IS NULL OR array_length(i_stratify, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_stratify array malformed'
|
||||
USING HINT = 'pass stratification field names as NAME[] literal, e.g. ''{color, size}''';
|
||||
END IF;
|
||||
|
||||
EXECUTE 'SELECT ARRAY[COUNT(DISTINCT ' || array_to_string(i_stratify, '), COUNT(DISTINCT ') || ')] FROM db4ai.v' || p_id
|
||||
INTO STRICT stratify_count;
|
||||
IF stratify_count IS NULL THEN
|
||||
RAISE EXCEPTION 'sample snapshot internal error2: %', p_id;
|
||||
END IF;
|
||||
|
||||
SELECT array_agg(ordered) FROM (SELECT unnest(i_stratify) ordered ORDER BY unnest(stratify_count)) INTO STRICT i_stratify;
|
||||
IF i_stratify IS NULL THEN
|
||||
RAISE EXCEPTION 'sample snapshot internal error3';
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
IF i_sample_comments IS NOT NULL THEN
|
||||
IF array_length(i_sample_comments, 1) IS NULL OR array_length(i_sample_comments, 2) IS NOT NULL THEN
|
||||
RAISE EXCEPTION 'i_sample_comments array malformed'
|
||||
USING HINT = 'pass sample comments as TEXT[] literal, e.g. ''{comment 1, comment 2}''';
|
||||
ELSIF array_length(i_sample_infixes, 1) <> array_length(i_sample_comments, 1) THEN
|
||||
RAISE EXCEPTION 'i_sample_infixes and i_sample_comments array length mismatch';
|
||||
END IF;
|
||||
END IF;
|
||||
|
||||
-- extract normalized projection list (private: nullable, shared: not null, user_cname: not null)
|
||||
p_sv_proj := substring(pg_get_viewdef('db4ai.v' || p_id), '^SELECT (.*), t[0-9]+\.xc_node_id, t[0-9]+\.ctid FROM.*$');
|
||||
mapping := array(SELECT unnest(ARRAY[ m[1], m[2], coalesce(m[3], replace(m[4],'""','"'))]) FROM regexp_matches(p_sv_proj,
|
||||
'(?:COALESCE\(t[0-9]+\.(f[0-9]+), )?t[0-9]+\.(f[0-9]+)(?:\))? AS (?:([^\s",]+)|"((?:[^"]*"")*[^"]*)")', 'g') m);
|
||||
|
||||
FOR idx IN 3 .. array_length(mapping, 1) BY 3 LOOP
|
||||
IF s_mode = 'MSS' THEN
|
||||
s_sv_proj := s_sv_proj || coalesce(mapping[idx-2], mapping[idx-1]) || ' AS ' || quote_ident(mapping[idx]) || ',';
|
||||
s_bt_proj := s_bt_proj || quote_ident(mapping[idx]) || ' AS ' || coalesce(mapping[idx-2], mapping[idx-1]) || ',';
|
||||
ELSIF s_mode = 'CSS' THEN
|
||||
IF mapping[idx-2] IS NULL THEN
|
||||
s_sv_proj := s_sv_proj || mapping[idx-1] || ' AS ' || quote_ident(mapping[idx]) || ',';
|
||||
ELSE
|
||||
s_sv_proj := s_sv_proj || 'coalesce(' || mapping[idx-2] || ',' || mapping[idx-1] || ') AS ' || quote_ident(mapping[idx]) || ',';
|
||||
END IF;
|
||||
END IF;
|
||||
s_uv_proj := s_uv_proj || quote_ident(mapping[idx]) || ',';
|
||||
END LOOP;
|
||||
|
||||
s_bt_dist := getdistributekey('db4ai.t' || coalesce(m_id, p_id));
|
||||
s_bt_dist := CASE WHEN s_bt_dist IS NULL
|
||||
THEN ' DISTRIBUTE BY REPLICATION'
|
||||
ELSE ' DISTRIBUTE BY HASH(' || s_bt_dist || ')' END; s_bt_dist = '';
|
||||
|
||||
FOR i IN 1 .. array_length(i_sample_infixes, 1) LOOP
|
||||
IF i_sample_infixes[i] IS NULL THEN
|
||||
RAISE EXCEPTION 'i_sample_infixes array contains NULL values';
|
||||
END IF;
|
||||
|
||||
IF i_sample_ratios[i] IS NULL THEN
|
||||
RAISE EXCEPTION 'i_sample_ratios array contains NULL values';
|
||||
END IF;
|
||||
|
||||
qual_name := p_name_vers[1] || i_sample_infixes[i] || s_vers_del || p_name_vers[2];
|
||||
IF char_length(qual_name) > 63 THEN
|
||||
RAISE EXCEPTION 'sample snapshot name too long: ''%''', qual_name;
|
||||
ELSE
|
||||
s_name := (i_schema, qual_name);
|
||||
qual_name := quote_ident(s_name.schema) || '.' || quote_ident(s_name.name);
|
||||
END IF;
|
||||
|
||||
IF i_sample_ratios[i] < 0 OR i_sample_ratios[i] > 1 THEN
|
||||
RAISE EXCEPTION 'sample ratio must be between 0 and 1';
|
||||
END IF;
|
||||
|
||||
-- SELECT nextval('db4ai.snapshot_sequence') INTO STRICT s_id;
|
||||
SELECT MAX(id)+1 FROM db4ai.snapshot INTO STRICT s_id; -- openGauss BUG: cannot create sequences in initdb
|
||||
|
||||
-- check for duplicate snapshot
|
||||
IF 0 < (SELECT COUNT(*) FROM db4ai.snapshot WHERE schema = s_name.schema AND name = s_name.name) THEN
|
||||
RAISE EXCEPTION 'snapshot % already exists' , qual_name;
|
||||
END IF;
|
||||
|
||||
-- SET seed TO 0.444;
|
||||
-- setseed(0.444);
|
||||
-- dbms_random.seed(0.888);
|
||||
|
||||
-- create / upgrade + prepare target snapshots for SQL DML/DDL operations
|
||||
IF s_mode = 'MSS' THEN
|
||||
exec_cmds := ARRAY [
|
||||
-- extract and propagate DISTRIBUTE BY from root MSS snapshot
|
||||
[ 'O','CREATE TABLE db4ai.t' || s_id || ' WITH (orientation = column, compression = low)' || s_bt_dist
|
||||
|| ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id || ' WHERE random() <= ' || i_sample_ratios[i] ],
|
||||
-- || ' AS SELECT ' || rtrim(s_bt_proj, ',') || ' FROM db4ai.v' || p_id || ' WHERE dbms_random.value(0, 1) <= ' || i_sample_ratios[i],
|
||||
[ 'O', 'COMMENT ON TABLE db4ai.t' || s_id || ' IS ''snapshot backing table, root is ' || qual_name || '''' ],
|
||||
[ 'O', 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || s_sv_proj || ' xc_node_id, ctid FROM db4ai.t' || s_id ]];
|
||||
ELSIF s_mode = 'CSS' THEN
|
||||
IF m_id IS NULL THEN
|
||||
exec_cmds := ARRAY [
|
||||
[ 'O', 'UPDATE db4ai.snapshot SET matrix_id = ' || p_id || ' WHERE schema = ''' || i_schema || ''' AND name = '''
|
||||
|| i_parent || '''' ],
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ADD _' || p_id || ' BOOLEAN NOT NULL DEFAULT TRUE' ],
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || p_id || ' ALTER _' || p_id || ' SET DEFAULT FALSE' ],
|
||||
[ 'O', 'CREATE OR REPLACE VIEW db4ai.v' || p_id || ' WITH(security_barrier) AS SELECT ' || p_sv_proj || ', xc_node_id, ctid FROM db4ai.t'
|
||||
|| p_id || ' WHERE _' || p_id ]];
|
||||
m_id := p_id;
|
||||
END IF;
|
||||
exec_cmds := exec_cmds || ARRAY [
|
||||
[ 'O', 'ALTER TABLE db4ai.t' || m_id || ' ADD _' || s_id || ' BOOLEAN NOT NULL DEFAULT FALSE' ],
|
||||
[ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id || ' AND random() <= '
|
||||
-- [ 'O', 'UPDATE db4ai.t' || m_id || ' SET _' || s_id || ' = TRUE WHERE _' || p_id || ' AND dbms_random.value(0, 1) <= '
|
||||
|| i_sample_ratios[i] ],
|
||||
[ 'O', 'CREATE VIEW db4ai.v' || s_id || ' WITH(security_barrier) AS SELECT ' || s_sv_proj || ' xc_node_id, ctid FROM db4ai.t' || m_id
|
||||
|| ' WHERE _' || s_id ]];
|
||||
END IF;
|
||||
|
||||
-- || ' AS SELECT ' || proj_list || ' FROM '
|
||||
-- || '(SELECT *, count(*) OVER() _cnt, row_number() OVER('
|
||||
-- || CASE WHEN i_stratify IS NOT NULL THEN 'ORDER BY ' || array_to_string(i_stratify, ', ') END
|
||||
-- || ') _row FROM db4ai.v' || p_id
|
||||
-- || ') WHERE round(_row/100 = 0
|
||||
--|| ' TABLESAMPLE SYSTEM ( ' || i_sample_ratios[i] || ') REPEATABLE (888)';
|
||||
|
||||
--SELECT * FROM (SELECT *, count(*) over()_ cnt, row_number() OVER(ORDER BY COLOR) _row FROM t) WHERE _row % (cnt/ 10) = 0;
|
||||
|
||||
-- Execute the queries
|
||||
RAISE NOTICE E'accumulated commands:\n%', array_to_string(exec_cmds, E'\n');
|
||||
IF 1 + array_length(exec_cmds, 1) <> (db4ai.prepare_snapshot_internal(
|
||||
s_id, p_id, m_id, r_id, s_name.schema, s_name.name,
|
||||
ARRAY [ 'SAMPLE ' || i_sample_infixes[i] || ' ' || i_sample_ratios[i] ||
|
||||
CASE WHEN i_stratify IS NULL THEN '' ELSE ' ' || i_stratify::TEXT END ],
|
||||
i_sample_comments[i], CURRENT_USER, 1, exec_cmds)).i_idx THEN
|
||||
RAISE EXCEPTION 'sample snapshot internal error1';
|
||||
END IF;
|
||||
|
||||
-- create custom view, owned by current user
|
||||
EXECUTE 'CREATE VIEW ' || qual_name || ' WITH(security_barrier) AS SELECT ' || rtrim(s_uv_proj, ',') || ' FROM db4ai.v' || s_id;
|
||||
EXECUTE 'COMMENT ON VIEW ' || qual_name || ' IS ''snapshot view backed by db4ai.v' || s_id
|
||||
|| CASE WHEN length(i_sample_comments[i]) > 0 THEN ' comment is "' || i_sample_comments[i] || '"' ELSE '' END || '''';
|
||||
EXECUTE 'ALTER VIEW ' || qual_name || ' OWNER TO ' || CURRENT_USER;
|
||||
|
||||
exec_cmds := NULL;
|
||||
|
||||
RETURN NEXT s_name;
|
||||
END LOOP;
|
||||
|
||||
END;
|
||||
$$;
|
||||
COMMENT ON FUNCTION db4ai.sample_snapshot() IS 'Create samples from a snapshot';
|
|
@ -0,0 +1,74 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* schema.sql
|
||||
* Schema for DB4AI.Snapshot functionality.
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/dbmind/db4ai/snapshots/schema.sql
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
GRANT USAGE ON SCHEMA db4ai TO PUBLIC;

--CREATE SEQUENCE db4ai.snapshot_sequence; -- openGauss BUG: cannot create sequences in initdb
--GRANT USAGE ON SEQUENCE db4ai.snapshot_sequence TO db4ai;

-- composite type for returning fully-qualified snapshot names (schema, name)
CREATE TYPE db4ai.snapshot_name AS ("schema" NAME, "name" NAME); -- openGauss BUG: array type not created during initdb

-- central catalog of DB4AI snapshots; one row per snapshot version
CREATE TABLE IF NOT EXISTS db4ai.snapshot
(
    id BIGINT UNIQUE,                           -- snapshot id (surrogate key)
    parent_id BIGINT,                           -- parent snapshot id (references snapshot.id)
    matrix_id BIGINT,                           -- matrix id from CSS snapshots, else NULL
                                                -- (references snapshot.id)
    root_id BIGINT,                             -- id of the initial snapshot, constructed via
                                                -- db4ai.create_snapshot() from operational data
                                                -- (references snapshot.id)
    schema NAME NOT NULL,                       -- schema where the snapshot view is exported
    name NAME NOT NULL,                         -- name of the snapshot, including version postfix
    owner NAME NOT NULL,                        -- name of the user who created this snapshot
    commands TEXT[] NOT NULL,                   -- complete list of SQL statements documenting how
                                                -- to generate this snapshot from its ancestor
    comment TEXT,                               -- description of the snapshot
    published BOOLEAN NOT NULL DEFAULT FALSE,   -- TRUE, iff the snapshot is currently published
    archived BOOLEAN NOT NULL DEFAULT FALSE,    -- TRUE, iff the snapshot is currently archived
    created TIMESTAMP DEFAULT CURRENT_TIMESTAMP,-- timestamp of snapshot creation date
    row_count BIGINT NOT NULL,                  -- number of rows in this snapshot
    PRIMARY KEY (schema, name)
) /* DISTRIBUTE BY REPLICATION */;
COMMENT ON TABLE db4ai.snapshot IS 'system catalog of meta-data on DB4AI snapshots';

-- mirror the inline column comments into the catalog so they show up in \d+ and tooling
COMMENT ON COLUMN db4ai.snapshot.id IS 'snapshot id (surrogate key)';
COMMENT ON COLUMN db4ai.snapshot.parent_id IS 'parent snapshot id (references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.matrix_id IS E'matrix id from CSS snapshots, else NULL\n'
    '(references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.root_id IS E'id of the initial snapshot, constructed via\n'
    'db4ai.create_snapshot() from operational data\n'
    '(references snapshot.id)';
COMMENT ON COLUMN db4ai.snapshot.schema IS 'schema where the snapshot view is exported';
COMMENT ON COLUMN db4ai.snapshot.name IS 'name of the snapshot, including version postfix';
COMMENT ON COLUMN db4ai.snapshot.owner IS 'name of the user who created this snapshot';
COMMENT ON COLUMN db4ai.snapshot.commands IS E'complete list of SQL statements documenting how\n'
    'to generate this snapshot from its ancestor';
COMMENT ON COLUMN db4ai.snapshot.comment IS 'description of the snapshot';
COMMENT ON COLUMN db4ai.snapshot.published IS 'TRUE, iff the snapshot is currently published';
COMMENT ON COLUMN db4ai.snapshot.archived IS 'TRUE, iff the snapshot is currently archived';
COMMENT ON COLUMN db4ai.snapshot.created IS 'timestamp of snapshot creation date';

-- public read-only access to snapshot catalog
REVOKE ALL PRIVILEGES ON db4ai.snapshot FROM PUBLIC;
GRANT SELECT ON db4ai.snapshot TO PUBLIC;
|
|
@ -23,6 +23,7 @@
|
|||
#include "catalog/gs_obsscaninfo.h"
|
||||
#include "catalog/pg_obsscaninfo.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "db4ai/create_model.h"
|
||||
#include "commands/createas.h"
|
||||
#include "commands/defrem.h"
|
||||
#include "commands/prepare.h"
|
||||
|
@ -87,6 +88,7 @@ THR_LOCAL explain_get_index_name_hook_type explain_get_index_name_hook = NULL;
|
|||
|
||||
extern TrackDesc trackdesc[];
|
||||
extern sortMessage sortmessage[];
|
||||
extern DestReceiver* CreateDestReceiver(CommandDest dest);
|
||||
|
||||
/* Array to record plan table column names, type, etc */
|
||||
static const PlanTableEntry plan_table_cols[] = {{"id", PLANID, INT4OID, ANAL_OPT},
|
||||
|
@ -653,6 +655,22 @@ static void ExplainOneQuery(
|
|||
|
||||
return;
|
||||
}
|
||||
else if (IsA(query->utilityStmt, CreateModelStmt)) {
|
||||
CreateModelStmt* cm = (CreateModelStmt*) query->utilityStmt;
|
||||
|
||||
/*
|
||||
* Create the tuple receiver object and insert hyperp it will need
|
||||
*/
|
||||
DestReceiverTrainModel* dest_train_model = NULL;
|
||||
dest_train_model = (DestReceiverTrainModel*) CreateDestReceiver(DestTrainModel);
|
||||
configure_dest_receiver_train_model(dest_train_model, (AlgorithmML) cm->algorithm, cm->model, queryString);
|
||||
|
||||
PlannedStmt* plan = plan_create_model(cm, queryString, params, (DestReceiver*)dest_train_model);
|
||||
|
||||
ExplainOnePlan(plan, into, es, queryString, dest, params);
|
||||
return;
|
||||
|
||||
}
|
||||
|
||||
ExplainOneUtility(query->utilityStmt, into, es, queryString, params);
|
||||
return;
|
||||
|
@ -1928,6 +1946,7 @@ static void ExplainNode(
|
|||
case T_WorkTableScan:
|
||||
case T_ForeignScan:
|
||||
case T_VecForeignScan:
|
||||
case T_GradientDescentState:
|
||||
ExplainScanTarget((Scan*)plan, es);
|
||||
break;
|
||||
case T_ExtensiblePlan:
|
||||
|
|
|
@ -477,6 +477,12 @@ void GetPlanNodePlainText(
|
|||
*pname = "Vector Merge";
|
||||
*sname = *pt_operation = "Vector Merge Join";
|
||||
break;
|
||||
case T_GradientDescent:
|
||||
*pname = *sname = *pt_options = "Gradient Descent";
|
||||
break;
|
||||
case T_KMeans:
|
||||
*pname = *sname = *pt_options = "K-Means";
|
||||
break;
|
||||
default:
|
||||
*pname = *sname = *pt_operation = "?\?\?";
|
||||
break;
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "access/xact.h"
|
||||
#include "commands/copy.h"
|
||||
#include "commands/createas.h"
|
||||
#include "db4ai/create_model.h"
|
||||
#include "commands/matview.h"
|
||||
#include "executor/functions.h"
|
||||
#include "executor/spi.h"
|
||||
|
@ -150,6 +151,10 @@ DestReceiver* CreateDestReceiver(CommandDest dest)
|
|||
case DestBatchLocalRoundRobin:
|
||||
case DestBatchHybrid:
|
||||
return createStreamDestReceiver(dest);
|
||||
|
||||
case DestTrainModel:
|
||||
return CreateTrainModelDestReceiver();
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -188,6 +193,7 @@ void EndCommand(const char* commandTag, CommandDest dest)
|
|||
case DestCopyOut:
|
||||
case DestSQLFunction:
|
||||
case DestTransientRel:
|
||||
case DestTrainModel:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -218,6 +224,7 @@ void EndCommand_noblock(const char* commandTag, CommandDest dest)
|
|||
case DestIntoRel:
|
||||
case DestCopyOut:
|
||||
case DestSQLFunction:
|
||||
case DestTrainModel:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -265,6 +272,7 @@ void NullCommand(CommandDest dest)
|
|||
case DestCopyOut:
|
||||
case DestSQLFunction:
|
||||
case DestTransientRel:
|
||||
case DestTrainModel:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -321,6 +329,7 @@ void ReadyForQuery(CommandDest dest)
|
|||
case DestCopyOut:
|
||||
case DestSQLFunction:
|
||||
case DestTransientRel:
|
||||
case DestTrainModel:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -355,6 +364,7 @@ void ReadyForQuery_noblock(CommandDest dest, int timeout)
|
|||
case DestIntoRel:
|
||||
case DestCopyOut:
|
||||
case DestSQLFunction:
|
||||
case DestTrainModel:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -102,6 +102,7 @@
|
|||
#include "gs_policy/gs_policy_audit.h"
|
||||
#include "gs_policy/policy_common.h"
|
||||
#include "client_logic/client_logic.h"
|
||||
#include "db4ai/create_model.h"
|
||||
#ifdef ENABLE_MULTIPLE_NODES
|
||||
#include "pgxc/pgFdwRemote.h"
|
||||
#include "pgxc/globalStatistic.h"
|
||||
|
@ -6649,6 +6650,13 @@ void standard_ProcessUtility(Node* parse_tree, const char* query_string, ParamLi
|
|||
(void)process_column_settings((CreateClientLogicColumn *)parse_tree);
|
||||
}
|
||||
break;
|
||||
|
||||
case T_CreateModelStmt:{ // DB4AI
|
||||
|
||||
exec_create_model((CreateModelStmt*) parse_tree, query_string, params, completion_tag);
|
||||
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE),
|
||||
|
@ -7729,6 +7737,9 @@ const char* CreateCommandTag(Node* parse_tree)
|
|||
case OBJECT_CONVERSION:
|
||||
tag = "DROP CONVERSION";
|
||||
break;
|
||||
case OBJECT_DB4AI_MODEL:
|
||||
tag = "DROP MODEL";
|
||||
break;
|
||||
case OBJECT_SCHEMA:
|
||||
tag = "DROP SCHEMA";
|
||||
break;
|
||||
|
@ -8390,6 +8401,9 @@ const char* CreateCommandTag(Node* parse_tree)
|
|||
case T_ShutdownStmt:
|
||||
tag = "SHUTDOWN";
|
||||
break;
|
||||
case T_CreateModelStmt:
|
||||
tag = "CREATE MODEL";
|
||||
break;
|
||||
|
||||
default:
|
||||
elog(WARNING, "unrecognized node type: %d", (int)nodeTag(parse_tree));
|
||||
|
@ -9136,6 +9150,9 @@ LogStmtLevel GetCommandLogLevel(Node* parse_tree)
|
|||
case T_ShutdownStmt:
|
||||
lev = LOGSTMT_ALL;
|
||||
break;
|
||||
case T_CreateModelStmt: // DB4AI
|
||||
lev = LOGSTMT_ALL;
|
||||
break;
|
||||
|
||||
default:
|
||||
elog(WARNING, "unrecognized node type: %d", (int)nodeTag(parse_tree));
|
||||
|
|
|
@ -47,7 +47,7 @@ OBJS = execAmi.o execCurrent.o execGrouping.o execJunk.o execMain.o \
|
|||
nodeGroup.o nodeSubplan.o nodeSubqueryscan.o nodeTidscan.o \
|
||||
nodeForeignscan.o nodeWindowAgg.o tstoreReceiver.o spi.o \
|
||||
nodePartIterator.o nodeStub.o execClusterResize.o lightProxy.o execMerge.o \
|
||||
nodeExtensible.o opfusion.o opfusion_scan.o opfusion_util.o route.o
|
||||
nodeExtensible.o opfusion.o opfusion_scan.o opfusion_util.o route.o nodeGD.o nodeKMeans.o
|
||||
|
||||
override CPPFLAGS += -D__STDC_FORMAT_MACROS
|
||||
|
||||
|
|
|
@ -162,6 +162,8 @@
|
|||
#include "securec.h"
|
||||
#include "gstrace/gstrace_infra.h"
|
||||
#include "gstrace/executer_gstrace.h"
|
||||
#include "executor/nodeGD.h"
|
||||
#include "executor/nodeKMeans.h"
|
||||
|
||||
#define NODENAMELEN 64
|
||||
|
||||
|
@ -391,6 +393,10 @@ PlanState* ExecInitNodeByType(Plan* node, EState* estate, int eflags)
|
|||
return (PlanState*)ExecInitVecMergeJoin((VecMergeJoin*)node, estate, eflags);
|
||||
case T_VecWindowAgg:
|
||||
return (PlanState*)ExecInitVecWindowAgg((VecWindowAgg*)node, estate, eflags);
|
||||
case T_GradientDescent:
|
||||
return (PlanState*)ExecInitGradientDescent((GradientDescent*)node, estate, eflags);
|
||||
case T_KMeans:
|
||||
return (PlanState*)ExecInitKMeans((KMeans*)node, estate, eflags);
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_EXECUTOR),
|
||||
|
@ -692,7 +698,10 @@ TupleTableSlot* ExecProcNodeByType(PlanState* node)
|
|||
result = ExecStream((StreamState*)node);
|
||||
t_thrd.pgxc_cxt.GlobalNetInstr = NULL;
|
||||
return result;
|
||||
|
||||
case T_GradientDescentState:
|
||||
return ExecGradientDescent((GradientDescentState*)node);
|
||||
case T_KMeansState:
|
||||
return ExecKMeans((KMeansState*)node);
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_EXECUTOR),
|
||||
|
@ -1343,6 +1352,14 @@ static void ExecEndNodeByType(PlanState* node)
|
|||
ExecEndVecWindowAgg((VecWindowAggState*)node);
|
||||
break;
|
||||
|
||||
case T_GradientDescentState:
|
||||
ExecEndGradientDescent((GradientDescentState*)node);
|
||||
break;
|
||||
|
||||
case T_KMeansState:
|
||||
ExecEndKMeans((KMeansState*)node);
|
||||
break;
|
||||
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_EXECUTOR),
|
||||
|
|
|
@ -69,6 +69,7 @@
|
|||
#include "gstrace/gstrace_infra.h"
|
||||
#include "gstrace/executer_gstrace.h"
|
||||
#include "commands/trigger.h"
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
/* static function decls */
|
||||
static Datum ExecEvalArrayRef(ArrayRefExprState* astate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
|
||||
|
@ -147,6 +148,8 @@ static Datum ExecEvalGroupingIdExpr(
|
|||
GroupingIdExprState* gstate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
|
||||
static bool func_has_refcursor_args(Oid Funcid, FunctionCallInfoData* fcinfo);
|
||||
|
||||
extern Datum ExecEvalGradientDescent(GradientDescentExprState* mlstate, ExprContext* econtext, bool* isNull, ExprDoneCond* isDone);
|
||||
|
||||
THR_LOCAL PLpgSQL_execstate* plpgsql_estate = NULL;
|
||||
|
||||
/* ----------------------------------------------------------------
|
||||
|
@ -5412,6 +5415,20 @@ ExprState* ExecInitExpr(Expr* node, PlanState* parent)
|
|||
state = (ExprState*)rnstate;
|
||||
state->evalfunc = (ExprStateEvalFunc)ExecEvalRownum;
|
||||
} break;
|
||||
case T_GradientDescentExpr: {
|
||||
GradientDescentExprState* ml_state = (GradientDescentExprState*)makeNode(GradientDescentExprState);
|
||||
ml_state->ps = parent;
|
||||
ml_state->xpr = (GradientDescentExpr*)node;
|
||||
state = (ExprState*)ml_state;
|
||||
if (IsA(parent, GradientDescentState)) {
|
||||
state->evalfunc = (ExprStateEvalFunc)ExecEvalGradientDescent;
|
||||
} else {
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_DB4AI),
|
||||
errcode(ERRCODE_INVALID_OPERATION),
|
||||
errmsg("unrecognized state %d for GradientDescentExpr", parent->type)));
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_UNRECOGNIZED_NODE_TYPE),
|
||||
|
|
|
@ -0,0 +1,457 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* nodeGD.cpp
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/runtime/executor/nodeGD.cpp
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "executor/executor.h"
|
||||
#include "executor/nodeGD.h"
|
||||
#include "db4ai/gd.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Optional observer hook: when non-null, ExecGradientDescent invokes it once
// after every completed training iteration (used e.g. for tracing/testing).
GradientDescentHook_iteration gdhook_iteration = nullptr;
|
||||
|
||||
/*
 * transfer_slot
 *      Copies one input tuple into the current training batch.
 *
 *      Writes the tuple's feature values into row 'ith_tuple' of the
 *      'features' matrix, appending a constant 1.0 bias term as the last
 *      column, and writes the dependent variable into 'dep_var'.
 *
 *      Returns true when the tuple was transferred, false when it was
 *      discarded because its target column is NULL. NULL values in
 *      non-target columns are replaced by 0.0 rather than discarded.
 */
static bool transfer_slot(GradientDescentState* gd_state, TupleTableSlot* slot,
    int ith_tuple, Matrix* features, Matrix* dep_var)
{
    const GradientDescent* gd_node = gd_get_node(gd_state);
    Assert(ith_tuple < (int)features->rows);

    // tuples with a NULL dependent variable cannot be used for training
    if (!slot->tts_isnull[gd_node->targetcol]) {
        int feature = 0;
        // 'w' walks the features-matrix row that this tuple fills
        gd_float* w = features->data + ith_tuple * features->columns;
        for (int i = 0; i < get_natts(gd_state); i++) {
            if (i == gd_node->targetcol && !dep_var_is_continuous(gd_state->algorithm)) {
                // discrete dependent variable: only binary classification is implemented
                if (!dep_var_is_binary(gd_state->algorithm)) {
                    ereport(ERROR,
                        (errmodule(MOD_DB4AI),
                        errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                        errmsg("categorical dependent variable not implemented")));
                }

                float dt = 0;
                if (get_atttypid(gd_state, gd_node->targetcol) == BOOLOID)
                    // booleans map directly onto the algorithm's two class values
                    dt = DatumGetBool(slot->tts_values[gd_node->targetcol])
                        ? gd_state->algorithm->max_class : gd_state->algorithm->min_class;
                else {
                    // map the datum onto one of (at most) two class values,
                    // learning the set of distinct classes on the fly
                    bool found = false;
                    for (int v = 0; v < gd_state->num_classes && !found; v++) {
                        found = datumIsEqual(slot->tts_values[gd_node->targetcol], gd_state->binary_classes[v],
                                            get_attbyval(gd_state, gd_node->targetcol),
                                            get_attlen(gd_state, gd_node->targetcol));
                        if (found)
                            // second registered class maps to max_class, first to min_class
                            dt = (v == 1 ? gd_state->algorithm->max_class : gd_state->algorithm->min_class);
                    }
                    if (!found) {
                        if (gd_state->num_classes == 2)
                            ereport(ERROR,
                                (errmodule(MOD_DB4AI),
                                errcode(ERRCODE_TOO_MANY_ARGUMENTS),
                                errmsg("too many target values for binary operator")));

                        // first sighting of this class value: remember a private copy
                        // NOTE(review): 'dt' keeps its initial 0 here instead of the class
                        // value this new index would map to — confirm this is intended
                        gd_state->binary_classes[gd_state->num_classes++] =
                            datumCopy(slot->tts_values[gd_node->targetcol],
                                      get_attbyval(gd_state, gd_node->targetcol),
                                      get_attlen(gd_state, gd_node->targetcol));
                    }
                }
                dep_var->data[ith_tuple] = dt;
            } else {
                gd_float value;
                if (slot->tts_isnull[i]) {
                    Assert(i != gd_node->targetcol);
                    value = 0.0; // default value for feature, it is not the target for sure
                }
                else
                    value = gd_datum_get_float(get_atttypid(gd_state, i), slot->tts_values[i]);

                if (i == gd_node->targetcol)
                    // continuous dependent variable (regression target)
                    dep_var->data[ith_tuple] = value;
                else {
                    *w++ = value;
                    feature++;
                }
            }
        }
        Assert(feature == gd_state->n_features-1);
        *w = 1.0; // bias

        return true;
    }

    return false;
}
|
||||
|
||||
/*
 * exec_gd_batch
 *      Runs one full scan over the child plan, batch by batch.
 *
 *      For each batch: fetches up to batch_size tuples from the outer plan,
 *      copies them into the matrices provided by the shuffle strategy via
 *      transfer_slot(), and — except on iteration 0, which only gathers the
 *      initial state — evaluates the current weights on the batch, adding
 *      the batch loss to gd_state->loss. The processed/discarded tuple
 *      counters are accumulated only on iteration 0 (the first full scan).
 */
void exec_gd_batch(GradientDescentState* gd_state, int iter)
{
    // get information from the node
    const GradientDescent* gd_node = gd_get_node(gd_state);
    PlanState* outer_plan = outerPlanState(gd_state);
    Matrix* features;
    Matrix* dep_var;
    bool more = true;
    TupleTableSlot* slot = NULL;
    do {
        // read next batch buffers from the shuffle strategy
        features = gd_state->shuffle->get(gd_state->shuffle, &dep_var);

        // fill the batch with up to batch_size usable tuples
        int ith_tuple = 0;
        while (more && ith_tuple < gd_node->batch_size) {
            slot = ExecProcNode(outer_plan);
            if (TupIsNull(slot)) {
                more = false;
            } else {
                if (transfer_slot(gd_state, slot, ith_tuple, features, dep_var)) {
                    if (iter == 0)
                        gd_state->processed++;

                    ith_tuple++;
                } else {
                    // tuple rejected (NULL target); counted only on the first scan
                    if (iter == 0)
                        gd_state->discarded++;
                }
            }
        }

        // use the batch to test now in case the shuffle algorithm
        // releases it during unget
        if (iter > 0 && ith_tuple > 0) {
            // shrink the matrices to the actual tuple count of a partial batch
            if (ith_tuple < gd_node->batch_size) {
                matrix_resize(features, ith_tuple, gd_state->n_features);
                matrix_resize(dep_var, ith_tuple, 1);
            }

            double loss = gd_state->algorithm->test_callback(gd_node, features, dep_var, &gd_state->weights, &gd_state->scores);
            gd_state->loss += loss;
            ereport(DEBUG1,
                (errmodule(MOD_DB4AI),
                errmsg("iteration %d loss = %.6f (total %.6g)", iter, loss, gd_state->loss)));

            // restore the full batch shape before handing the buffers back
            if (ith_tuple < gd_node->batch_size) {
                matrix_resize(features, gd_node->batch_size, gd_state->n_features);
                matrix_resize(dep_var, gd_node->batch_size, 1);
            }
        }

        // give back the batch to the shuffle algorithm
        gd_state->shuffle->unget(gd_state->shuffle, ith_tuple);
    } while (more);
}
|
||||
|
||||
/*
 * exec_gd_start_iteration
 *      Notifies the optimizer and the shuffle strategy that a new training
 *      iteration is about to begin. Both callbacks are optional: a nullptr
 *      entry means that component needs no per-iteration setup.
 */
void exec_gd_start_iteration(GradientDescentState* gd_state)
{
    auto* optimizer = gd_state->optimizer;
    auto* shuffle = gd_state->shuffle;

    // optimizer first — mirrored by the reverse order in exec_gd_end_iteration
    if (optimizer->start_iteration != nullptr) {
        optimizer->start_iteration(optimizer);
    }

    if (shuffle->start_iteration != nullptr) {
        shuffle->start_iteration(shuffle);
    }
}
|
||||
|
||||
/*
 * exec_gd_end_iteration
 *      Notifies the shuffle strategy and the optimizer that the current
 *      training iteration has finished. Both callbacks are optional: a
 *      nullptr entry means that component needs no per-iteration teardown.
 */
void exec_gd_end_iteration(GradientDescentState* gd_state)
{
    auto* shuffle = gd_state->shuffle;
    auto* optimizer = gd_state->optimizer;

    // shuffle first — reverse of the setup order in exec_gd_start_iteration
    if (shuffle->end_iteration != nullptr) {
        shuffle->end_iteration(shuffle);
    }

    if (optimizer->end_iteration != nullptr) {
        optimizer->end_iteration(optimizer);
    }
}
|
||||
|
||||
/* ----------------------------------------------------------------
|
||||
* ExecGradientDescent
|
||||
* ----------------------------------------------------------------
|
||||
*
|
||||
* Training and test are interleaved to avoid a double scan over the data
|
||||
* for training and test. Iteration 0 only computes the initial weights, and
|
||||
* at each following iteration the model is tested with the current weights
|
||||
* and new weights are updated with the gradients. The optimization is clear:
|
||||
* for N iterations, the basic algorithm requires N*2 data scans, while the
|
||||
* interleaved train&test requires only N+1 data scans. When N=1 the number
|
||||
* of scans is the same (N*2 = N+1)
|
||||
*/
|
||||
/*
 * ExecGradientDescent
 *      Drives the full interleaved train & test loop (see the block comment
 *      above) and returns a single projected tuple with the trained model.
 *      Returns NULL on a backwards scan or once training has completed.
 */
TupleTableSlot* ExecGradientDescent(GradientDescentState* gd_state)
{
    // check if training is already finished
    if (gd_state->done)
        return NULL;

    // get information from the node
    const GradientDescent* gd_node = gd_get_node(gd_state);
    ScanDirection direction = gd_state->ss.ps.state->es_direction;
    PlanState* outer_plan = outerPlanState(gd_state);

    // If backwards scan, just return NULL without changing state.
    if (!ScanDirectionIsForward(direction))
        return NULL;

    // for counting execution time
    // NOTE(review): iter_start/iter_finish are recorded but never read
    uint64_t start, finish, step;
    uint64_t iter_start, iter_finish;

    // iterations
    double prev_loss = 0;
    TupleTableSlot* slot = NULL;

    gd_state->processed = 0;
    gd_state->discarded = 0;

    // overall wall-clock budget; unlimited when max_seconds <= 0
    uint64_t max_usecs = ULLONG_MAX;
    if (gd_node->max_seconds > 0)
        max_usecs = gd_node->max_seconds * 1000000ULL;

    // iteration 0 only computes initial weights; hence "<=" max_iterations
    bool stop = false;
    start = gd_get_clock_usecs();
    step = start;
    for (int iter = 0; !stop && iter <= gd_node->max_iterations; iter++) {
        iter_start = gd_get_clock_usecs();

        // init loss & scores
        scores_init(&gd_state->scores);
        gd_state->loss = 0;

        // one full scan: test current weights and accumulate gradients
        exec_gd_start_iteration(gd_state);
        exec_gd_batch(gd_state, iter);
        exec_gd_end_iteration(gd_state);

        iter_finish = gd_get_clock_usecs();

        // delta loss < loss tolerance?
        if (iter > 0)
            stop = (fabs(prev_loss - gd_state->loss) < gd_node->tolerance);

        if (!stop) {
            // continue with another iteration with the new weights
            int bytes = sizeof(gd_float) * gd_state->n_features;
            int rc = memcpy_s(gd_state->weights.data, gd_state->weights.allocated * sizeof(gd_float),
                                gd_state->optimizer->weights.data, bytes);
            securec_check(rc, "", "");

            if (iter > 0) {
                gd_state->n_iterations++;
                // notify the optional per-iteration observer hook
                if (gdhook_iteration != nullptr)
                    gdhook_iteration(gd_state);
            }

            // timeout || max_iterations
            stop = (gd_get_clock_usecs()-start >= max_usecs)
                || (iter == gd_node->max_iterations);
        }

        // trace at end or no more than once per second
        bool trace_iteration = gd_node->verbose || stop;
        if (!trace_iteration) {
            uint64_t now = gd_get_clock_usecs();
            uint64_t nusecs = now - step;
            if (nusecs > 1000000) {
                // more than one second
                trace_iteration = true;
                step = now;
            }
        }

        if (iter>0 && trace_iteration) {
            // dump current loss, accuracy and coefficients at DEBUG1
            gd_float* w = gd_state->weights.data;
            StringInfoData buf;
            initStringInfo(&buf);
            for (int i=0 ; i<gd_state->n_features ; i++)
                appendStringInfo(&buf, "%.3f,", w[i]);

            ereport(DEBUG1,
                (errmodule(MOD_DB4AI),
                errmsg("ITERATION %d: test_loss=%.6f delta_loss=%.6f tolerance=%.3f accuracy=%.3f tuples=%d coef=%s",
                    iter, gd_state->loss,
                    fabs(prev_loss - gd_state->loss), gd_node->tolerance,
                    get_accuracy(&gd_state->scores), gd_state->processed,
                    buf.data)));
            pfree(buf.data);
        }

        prev_loss = gd_state->loss;

        if (!stop)
            ExecReScan(outer_plan); // for the next iteration
    }

    finish = gd_get_clock_usecs();

    gd_state->done = true;
    gd_state->usecs = finish - start;

    // return trained model
    ExprDoneCond isDone;
    slot = ExecProject(gd_state->ss.ps.ps_ProjInfo, &isDone);

    return slot;
}
|
||||
|
||||
/* ----------------------------------------------------------------
|
||||
* ExecInitGradientDescent
|
||||
*
|
||||
* This initializes the GradientDescent node state structures and
|
||||
* the node's subplan.
|
||||
* ----------------------------------------------------------------
|
||||
*/
|
||||
GradientDescentState* ExecInitGradientDescent(GradientDescent* gd_node, EState* estate, int eflags)
|
||||
{
|
||||
GradientDescentState* gd_state = NULL;
|
||||
Plan* outer_plan = outerPlan(gd_node);
|
||||
|
||||
// check for unsupported flags
|
||||
Assert(!(eflags & (EXEC_FLAG_REWIND | EXEC_FLAG_BACKWARD | EXEC_FLAG_MARK)));
|
||||
|
||||
// create state structure
|
||||
gd_state = makeNode(GradientDescentState);
|
||||
gd_state->ss.ps.plan = (Plan*)gd_node;
|
||||
gd_state->ss.ps.state = estate;
|
||||
|
||||
// Tuple table initialization
|
||||
ExecInitScanTupleSlot(estate, &gd_state->ss);
|
||||
ExecInitResultTupleSlot(estate, &gd_state->ss.ps);
|
||||
|
||||
// initialize child expressions
|
||||
ExecAssignExprContext(estate, &gd_state->ss.ps);
|
||||
gd_state->ss.ps.targetlist = (List*)ExecInitExpr((Expr*)gd_node->plan.targetlist, (PlanState*)gd_state);
|
||||
|
||||
// initialize outer plan
|
||||
outerPlanState(gd_state) = ExecInitNode(outer_plan, estate, eflags);
|
||||
|
||||
// Initialize result tuple type and projection info.
|
||||
ExecAssignScanTypeFromOuterPlan(&gd_state->ss); // input tuples
|
||||
ExecAssignResultTypeFromTL(&gd_state->ss.ps); // result tuple
|
||||
ExecAssignProjectionInfo(&gd_state->ss.ps, NULL);
|
||||
gd_state->ss.ps.ps_TupFromTlist = false;
|
||||
|
||||
// select algorithm
|
||||
gd_state->algorithm = gd_get_algorithm(gd_node->algorithm);
|
||||
|
||||
// Input tuple initialization
|
||||
gd_state->tupdesc = ExecGetResultType(outerPlanState(gd_state));
|
||||
|
||||
int natts = gd_state->tupdesc->natts;
|
||||
gd_state->n_features = natts; // -1 dep_var, +1 bias (fixed as 1)
|
||||
|
||||
for (int i = 0; i < natts; i++) {
|
||||
Oid oidtype = gd_state->tupdesc->attrs[i]->atttypid;
|
||||
if (i == gd_node->targetcol) {
|
||||
switch (oidtype) {
|
||||
case BITOID:
|
||||
case VARBITOID:
|
||||
case BYTEAOID:
|
||||
case CHAROID:
|
||||
case RAWOID:
|
||||
case NAMEOID:
|
||||
case TEXTOID:
|
||||
case BPCHAROID:
|
||||
case VARCHAROID:
|
||||
case NVARCHAR2OID:
|
||||
case CSTRINGOID:
|
||||
case INT1OID:
|
||||
case INT2OID:
|
||||
case INT4OID:
|
||||
case INT8OID:
|
||||
case FLOAT4OID:
|
||||
case FLOAT8OID:
|
||||
case NUMERICOID:
|
||||
case ABSTIMEOID:
|
||||
case DATEOID:
|
||||
case TIMEOID:
|
||||
case TIMESTAMPOID:
|
||||
case TIMESTAMPTZOID:
|
||||
case TIMETZOID:
|
||||
case SMALLDATETIMEOID:
|
||||
// detect the different values while reading the data
|
||||
gd_state->num_classes = 0;
|
||||
break;
|
||||
|
||||
case BOOLOID:
|
||||
// values are known in advance
|
||||
gd_state->binary_classes[0] = BoolGetDatum(false);
|
||||
gd_state->binary_classes[1] = BoolGetDatum(true);
|
||||
gd_state->num_classes = 2;
|
||||
break;
|
||||
|
||||
default:
|
||||
// unsupported datatypes
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_DB4AI),
|
||||
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Datatype of target not supported")));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// optimizer
|
||||
switch (gd_node->optimizer) {
|
||||
case OPTIMIZER_GD:
|
||||
gd_state->optimizer = gd_init_optimizer_gd(gd_state);
|
||||
break;
|
||||
case OPTIMIZER_NGD:
|
||||
gd_state->optimizer = gd_init_optimizer_ngd(gd_state);
|
||||
break;
|
||||
default:
|
||||
ereport(ERROR,
|
||||
(errmodule(MOD_DB4AI),
|
||||
errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
|
||||
errmsg("Optimizer %d not supported", gd_node->optimizer)));
|
||||
break;
|
||||
}
|
||||
matrix_init(&gd_state->optimizer->weights, gd_state->n_features);
|
||||
matrix_init(&gd_state->optimizer->gradients, gd_state->n_features);
|
||||
|
||||
// shuffle
|
||||
gd_state->shuffle = gd_init_shuffle_cache(gd_state);
|
||||
gd_state->shuffle->optimizer = gd_state->optimizer;
|
||||
|
||||
// training state initialization
|
||||
gd_state->done = false;
|
||||
gd_state->learning_rate = gd_node->learning_rate;
|
||||
gd_state->n_iterations = 0;
|
||||
gd_state->loss = 0;
|
||||
matrix_init(&gd_state->weights, gd_state->n_features);
|
||||
|
||||
return gd_state;
|
||||
}
|
||||
|
||||
/* ----------------------------------------------------------------
|
||||
* ExecEndGradientDescent
|
||||
*
|
||||
* This shuts down the subplan and frees resources allocated
|
||||
* to this node.
|
||||
* ----------------------------------------------------------------
|
||||
*/
|
||||
void ExecEndGradientDescent(GradientDescentState* gd_state)
|
||||
{
|
||||
// release state
|
||||
matrix_release(&gd_state->weights);
|
||||
|
||||
gd_state->shuffle->release(gd_state->shuffle);
|
||||
|
||||
matrix_release(&gd_state->optimizer->gradients);
|
||||
matrix_release(&gd_state->optimizer->weights);
|
||||
gd_state->optimizer->release(gd_state->optimizer);
|
||||
|
||||
ExecFreeExprContext(&gd_state->ss.ps);
|
||||
ExecEndNode(outerPlanState(gd_state));
|
||||
pfree(gd_state);
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -60,6 +60,6 @@
|
|||
#endif
|
||||
|
||||
#define NAILED_IN_CATALOG_NUM 8
|
||||
#define CATALOG_NUM 91
|
||||
#define CATALOG_NUM 92
|
||||
|
||||
#endif
|
||||
|
|
|
@ -204,6 +204,7 @@ typedef enum ObjectClass {
|
|||
OCLASS_DIRECTORY, /* pg_directory */
|
||||
OCLASS_PG_JOB, /* pg_job */
|
||||
OCLASS_RLSPOLICY, /* pg_rlspolicy */
|
||||
OCLASS_DB4AI_MODEL, /* gs_model_warehouse */
|
||||
MAX_OCLASS /* MUST BE LAST */
|
||||
} ObjectClass;
|
||||
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* -------------------------------------------------------------------------
|
||||
*
|
||||
* gs_model.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/catalog/gs_model.h
|
||||
*
|
||||
* -------------------------------------------------------------------------
|
||||
*/
|
||||
#ifndef GS_MODEL_H
|
||||
#define GS_MODEL_H
|
||||
|
||||
#include "access/genam.h"
|
||||
#include "access/heapam.h"
|
||||
#include "access/sysattr.h"
|
||||
#include "catalog/genbki.h"
|
||||
#include "catalog/indexing.h"
|
||||
#include "utils/fmgroids.h"
|
||||
#include "utils/snapmgr.h"
|
||||
#include "utils/syscache.h"
|
||||
#include "utils/timestamp.h"
|
||||
#include "utils/date.h"
|
||||
|
||||
#define ModelRelationId 3991
|
||||
#define ModelRelation_Rowtype_Id 3994
|
||||
|
||||
#ifdef HAVE_INT64_TIMESTAMP
|
||||
#define timestamp int64
|
||||
#else
|
||||
#define timestamp double
|
||||
#endif
|
||||
|
||||
CATALOG(gs_model_warehouse,3991) BKI_ROWTYPE_OID(3994) BKI_SCHEMA_MACRO
|
||||
{
|
||||
NameData modelname; /* model name */
|
||||
Oid modelowner; /* model owner */
|
||||
timestamp createtime; /* Model storage time */
|
||||
int4 processedtuples;
|
||||
int4 discardedtuples;
|
||||
float4 pre_process_time;
|
||||
float4 exec_time;
|
||||
int4 iterations;
|
||||
Oid outputtype;
|
||||
#ifdef CATALOG_VARLEN /* variable-length fields start here */
|
||||
text modeltype; /* svm、kmeans、invalid */
|
||||
text query;
|
||||
bytea modeldata; /* OPTION just for ModelBinary*/
|
||||
float4 weight[1];
|
||||
text hyperparametersnames[1];
|
||||
text hyperparametersvalues[1];
|
||||
Oid hyperparametersoids[1];
|
||||
text coefnames[1];
|
||||
text coefvalues[1];
|
||||
Oid coefoids[1];
|
||||
text trainingscoresname[1];
|
||||
float4 trainingscoresvalue[1];
|
||||
text modeldescribe[1];
|
||||
#endif /* CATALOG_VARLEN */
|
||||
} FormData_gs_model_warehouse;
|
||||
|
||||
typedef FormData_gs_model_warehouse *Form_gs_model_warehouse;
|
||||
|
||||
#define Natts_gs_model_warehouse 22
|
||||
#define Anum_gs_model_model_name 1
|
||||
#define Anum_gs_model_owner_oid 2
|
||||
#define Anum_gs_model_create_time 3
|
||||
#define Anum_gs_model_processedTuples 4
|
||||
#define Anum_gs_model_discardedTuples 5
|
||||
#define Anum_gs_model_process_time_secs 6
|
||||
#define Anum_gs_model_exec_time_secs 7
|
||||
#define Anum_gs_model_iterations 8
|
||||
#define Anum_gs_model_outputType 9
|
||||
|
||||
#define Anum_gs_model_model_type 10
|
||||
#define Anum_gs_model_query 11
|
||||
#define Anum_gs_model_modelData 12
|
||||
#define Anum_gs_model_weight 13
|
||||
#define Anum_gs_model_hyperparametersNames 14
|
||||
#define Anum_gs_model_hyperparametersValues 15
|
||||
#define Anum_gs_model_hyperparametersOids 16
|
||||
#define Anum_gs_model_coefNames 17
|
||||
#define Anum_gs_model_coefValues 18
|
||||
#define Anum_gs_model_coefOids 19
|
||||
#define Anum_gs_model_trainingScoresName 20
|
||||
#define Anum_gs_model_trainingScoresValue 21
|
||||
#define Anum_gs_model_modeldescribe 22
|
||||
|
||||
|
||||
|
||||
// Locate the oid for a given model name
|
||||
inline Oid get_model_oid(const char* model_name, bool missing_ok)
|
||||
{
|
||||
Oid oid;
|
||||
|
||||
oid = GetSysCacheOid1(DB4AI_MODEL, CStringGetDatum(model_name));
|
||||
if (!OidIsValid(oid) && !missing_ok) {
|
||||
ereport(
|
||||
ERROR, (errcode(ERRCODE_UNDEFINED_OBJECT), errmsg("model \"%s\" does not exist", model_name)));
|
||||
}
|
||||
return oid;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Remove a model by oid
|
||||
inline void remove_model_by_oid(Oid model_oid)
|
||||
{
|
||||
Relation rel;
|
||||
HeapTuple tup;
|
||||
ScanKeyData skey[1];
|
||||
SysScanDesc scan;
|
||||
|
||||
ScanKeyInit(&skey[0], ObjectIdAttributeNumber, BTEqualStrategyNumber, F_OIDEQ, ObjectIdGetDatum(model_oid));
|
||||
|
||||
rel = heap_open(ModelRelationId, RowExclusiveLock);
|
||||
|
||||
scan = systable_beginscan(rel, GsModelOidIndexId, true, SnapshotNow, 1, skey);
|
||||
|
||||
/* we expect exactly one match */
|
||||
tup = systable_getnext(scan);
|
||||
if (!HeapTupleIsValid(tup))
|
||||
ereport(
|
||||
ERROR, (errcode(ERRCODE_UNDEFINED_OBJECT), errmsg("could not find tuple for model entry %u", model_oid)));
|
||||
|
||||
simple_heap_delete(rel, &tup->t_self);
|
||||
|
||||
systable_endscan(scan);
|
||||
heap_close(rel, RowExclusiveLock);
|
||||
}
|
||||
|
||||
#endif
|
|
@ -547,6 +547,11 @@ DECLARE_UNIQUE_INDEX(gs_matviewdep_oid_index, 9992, on gs_matview_dependency usi
|
|||
DECLARE_UNIQUE_INDEX(gs_opt_model_name_index, 9997, on gs_opt_model using btree(model_name name_ops));
|
||||
#define GsOPTModelNameIndexId 9997
|
||||
|
||||
DECLARE_UNIQUE_INDEX(gs_model_oid_index, 3992, on gs_model_warehouse using btree(oid oid_ops));
|
||||
#define GsModelOidIndexId 3992
|
||||
DECLARE_UNIQUE_INDEX(gs_model_name_index, 3993, on gs_model_warehouse using btree(modelname name_ops));
|
||||
#define GsModelNameIndexId 3993
|
||||
|
||||
/* last step of initialization script: build the indexes declared above */
|
||||
BUILD_INDICES
|
||||
|
||||
|
|
|
@ -98,6 +98,9 @@ DATA(insert OID = 4989 ( "snapshot" PGUID 0 _null_ n));
|
|||
DESCR("snapshot schema");
|
||||
#define PG_SNAPSHOT_NAMESPACE 4989
|
||||
|
||||
DATA(insert OID = 4991 ( "db4ai" PGUID 0 _null_ n));
|
||||
DESCR("db4ai schema");
|
||||
#define PG_DB4AI_NAMESPACE 4991
|
||||
/*
|
||||
* prototypes for functions in pg_namespace.c
|
||||
*/
|
||||
|
|
|
@ -398,6 +398,13 @@ typedef FormData_pg_proc *Form_pg_proc;
|
|||
#define PGCHECKAUTHIDFUNCOID 3228
|
||||
#define JSONAGGFUNCOID 5206
|
||||
#define JSONOBJECTAGGFUNCOID 5209
|
||||
#define DB4AI_PREDICT_BY_BOOL_OID 7101
|
||||
#define DB4AI_PREDICT_BY_INT32_OID 7102
|
||||
#define DB4AI_PREDICT_BY_INT64_OID 7103
|
||||
#define DB4AI_PREDICT_BY_FLOAT4_OID 7105
|
||||
#define DB4AI_PREDICT_BY_FLOAT8_OID 7106
|
||||
#define DB4AI_PREDICT_BY_NUMERIC_OID 7107
|
||||
#define DB4AI_PREDICT_BY_TEXT_OID 7108
|
||||
|
||||
/*
|
||||
* Symbolic values for prokind column
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#define PgxcNodeRelationId 9015
|
||||
#define PgxcNodeRelation_Rowtype_Id 11649
|
||||
|
||||
CATALOG(pgxc_node,9015) BKI_SHARED_RELATION BKI_SCHEMA_MACRO
|
||||
CATALOG(pgxc_node,9015) BKI_ROWTYPE_OID(11649) BKI_SHARED_RELATION BKI_SCHEMA_MACRO
|
||||
{
|
||||
NameData node_name;
|
||||
|
||||
|
|
|
@ -56,6 +56,7 @@ DECLARE_TOAST(pg_trigger, 2336, 2337);
|
|||
DECLARE_TOAST(pg_partition, 5502, 5503);
|
||||
DECLARE_TOAST(pgxc_class, 5506, 5507);
|
||||
DECLARE_TOAST(pg_hashbucket, 4390, 4391);
|
||||
DECLARE_TOAST(gs_model_warehouse, 3995, 3996);
|
||||
|
||||
/* shared catalogs */
|
||||
DECLARE_TOAST(pg_shdescription, 2846, 2847);
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* plannodes.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/dbmind/db4ai/nodes/aifuncs.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef DB4AI_AIFUNCS_H
|
||||
#define DB4AI_AIFUNCS_H
|
||||
|
||||
#include "nodes/plannodes.h"
|
||||
|
||||
inline const char* algorithm_ml_to_string(AlgorithmML x)
|
||||
{
|
||||
switch(x) {
|
||||
case LOGISTIC_REGRESSION: return "logistic_regression";
|
||||
case SVM_CLASSIFICATION: return "svm_classification";
|
||||
case LINEAR_REGRESSION: return "linear_regression";
|
||||
case KMEANS: return "kmeans";
|
||||
case INVALID_ALGORITHM_ML:
|
||||
default: return "INVALID_ALGORITHM_ML";
|
||||
}
|
||||
}
|
||||
|
||||
inline AlgorithmML get_algorithm_ml(const char *str)
|
||||
{
|
||||
if (0 == strcmp(str, "logistic_regression")) {
|
||||
return LOGISTIC_REGRESSION;
|
||||
} else if (0 == strcmp(str, "svm_classification")) {
|
||||
return SVM_CLASSIFICATION;
|
||||
} else if (0 == strcmp(str, "linear_regression")) {
|
||||
return LINEAR_REGRESSION;
|
||||
} else if (0 == strcmp(str, "kmeans")) {
|
||||
return KMEANS;
|
||||
} else {
|
||||
return INVALID_ALGORITHM_ML;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_supervised(AlgorithmML algorithm)
|
||||
{
|
||||
if (algorithm == KMEANS) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // DB4AI_AIFUNCS_H
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*
|
||||
* command.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/dbmind/db4ai/commands/create_model.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef CREATE_MODEL_H
|
||||
#define CREATE_MODEL_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "nodes/params.h"
|
||||
#include "nodes/parsenodes.h"
|
||||
#include "nodes/parsenodes_common.h"
|
||||
#include "nodes/plannodes.h"
|
||||
#include "tcop/dest.h"
|
||||
|
||||
struct Model;
|
||||
|
||||
struct DestReceiverTrainModel {
|
||||
DestReceiver dest;
|
||||
Model *model;
|
||||
List *targetlist; // for gradient descent
|
||||
};
|
||||
|
||||
void configure_dest_receiver_train_model(DestReceiverTrainModel *dest, AlgorithmML algorithm, const char *model_name,
|
||||
const char *sql);
|
||||
|
||||
// Create a DestReceiver object for training model operators
|
||||
DestReceiver *CreateTrainModelDestReceiver();
|
||||
|
||||
// Rewrite a create model query, and plan the query. This method is used in query execution
|
||||
// and for explain statements
|
||||
PlannedStmt *plan_create_model(CreateModelStmt *stmt, const char *query_string, ParamListInfo params,
|
||||
DestReceiver *dest);
|
||||
|
||||
// Call executor
|
||||
void exec_create_model(CreateModelStmt *stmt, const char *queryString, ParamListInfo params, char *completionTag);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
mrc_cpu.h
|
||||
Code to provide micro-optimizations
|
||||
|
||||
IDENTIFICATION
|
||||
src/include/dbmind/db4ai/executor/mrc_cpu.h
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
**/
|
||||
|
||||
#ifndef DB4AI_MRC_CPU_H
|
||||
#define DB4AI_MRC_CPU_H
|
||||
|
||||
#if defined(__GNUC__)
|
||||
|
||||
#ifndef likely
|
||||
#define likely(x) __builtin_expect((x) != 0, 1)
|
||||
#endif
|
||||
|
||||
#ifndef unlikely
|
||||
#define unlikely(x) __builtin_expect((x) != 0, 0)
|
||||
#endif
|
||||
|
||||
#ifndef force_alignment
|
||||
#define force_alignment(x) __attribute__((aligned((x))))
|
||||
#endif
|
||||
|
||||
#ifndef force_inline
|
||||
#define force_inline inline __attribute__((always_inline))
|
||||
#endif
|
||||
|
||||
#ifndef prefetch
|
||||
#define prefetch(address, rw, locality) (__builtin_prefetch(address, rw, locality))
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
#define likely(x) (x)
|
||||
#define unlikely(x) (x)
|
||||
#define prefetch(address, rw, locality) ()
|
||||
#define force_alignment(x) ()
|
||||
#define force_inline ()
|
||||
#endif
|
||||
|
||||
#endif //DB4AI_MRC_CPU_H
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
distance_functions.h
|
||||
Current set of distance functions that can be used (for k-means for example)
|
||||
|
||||
IDENTIFICATION
|
||||
src/include/dbmind/db4ai/executor/distance_functions.h
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
**/
|
||||
|
||||
#ifndef DB4AI_DISTANCE_FUNCTIONS_H
|
||||
#define DB4AI_DISTANCE_FUNCTIONS_H
|
||||
|
||||
#include <cinttypes>
|
||||
|
||||
/*
|
||||
* L1 distance (Manhattan)
|
||||
*/
|
||||
extern double l1(double const* p, double const* q, uint32_t dimension);
|
||||
|
||||
/*
|
||||
* Squared Euclidean (default)
|
||||
*/
|
||||
extern double l2_squared(double const* p, double const* q, uint32_t dimension);
|
||||
|
||||
/*
|
||||
* L2 distance (Euclidean)
|
||||
*/
|
||||
extern double l2(double const* p, double const* q, uint32_t dimension);
|
||||
|
||||
/*
|
||||
* L infinity distance (Chebyshev)
|
||||
*/
|
||||
extern double linf(double const* p, double const* q, uint32_t dimension);
|
||||
|
||||
#endif //DB4AI_DISTANCE_FUNCTIONS_H
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
fp_ops.h
|
||||
Robust floating point operations
|
||||
|
||||
IDENTIFICATION
|
||||
src/include/dbmind/db4ai/executor/fp_ops.h
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
**/
|
||||
|
||||
#ifndef DB4AI_FP_OPS_H
|
||||
#define DB4AI_FP_OPS_H
|
||||
|
||||
/*
|
||||
* High precision sum: a + b = *sum + *e
|
||||
*/
|
||||
extern void twoSum(double a, double b, double* sum, double* e);
|
||||
|
||||
/*
|
||||
* The equivalent subtraction a - b = *sub + *e
|
||||
*/
|
||||
extern void twoDiff(double a, double b, double* sub, double* e);
|
||||
|
||||
/*
|
||||
* High precision product a * b = *mult + *e
|
||||
*/
|
||||
extern void twoMult(double a, double b, double* mult, double* e);
|
||||
|
||||
/*
|
||||
* High precision square a * a = *square + *e (faster than twoMult(a, a,..))
|
||||
*/
|
||||
extern void square(double a, double* square, double* e);
|
||||
|
||||
/*
|
||||
* High precision division a / b = *div + *e
|
||||
*/
|
||||
extern void twoDiv(double a, double b, double* div, double* e);
|
||||
|
||||
#endif //DB4AI_FP_OPS_H
|
|
@ -0,0 +1,163 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* gd.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/db4ai/gd.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef GD_H
|
||||
#define GD_H
|
||||
|
||||
#include "nodes/execnodes.h"
|
||||
#include "db4ai/predict_by.h"
|
||||
|
||||
// returns the current time in microseconds
|
||||
inline uint64_t
|
||||
gd_get_clock_usecs() {
|
||||
struct timeval tv;
|
||||
gettimeofday(&tv, NULL);
|
||||
return tv.tv_sec * 1000000ULL + tv.tv_usec;
|
||||
}
|
||||
|
||||
enum {
|
||||
GD_DEPENDENT_VAR_CONTINUOUS = 0x00000000,
|
||||
GD_DEPENDENT_VAR_BINARY = 0x00000001,
|
||||
GD_DEPENDENT_VAR_CATEGORICAL = 0x00000002,
|
||||
};
|
||||
|
||||
enum {
|
||||
METRIC_ACCURACY = 0x0001, // (tp + tn) / n
|
||||
METRIC_F1 = 0x0002, // 2 * (precision * recall) / (precision + recall)
|
||||
METRIC_PRECISION = 0x0004, // tp / (tp + fp)
|
||||
METRIC_RECALL = 0x0008, // tp / (tp + fn)
|
||||
METRIC_LOSS = 0x0010, // defined by each algorithm
|
||||
METRIC_MSE = 0x0020, // sum((y-y')^2)) / n
|
||||
};
|
||||
|
||||
char* gd_get_metric_name(int metric);
|
||||
|
||||
struct GradientDescent;
|
||||
|
||||
typedef void (*f_gd_gradients)(const GradientDescent* gd_node, const Matrix* features, const Matrix* dep_var,
|
||||
Matrix* weights, Matrix* gradients);
|
||||
typedef double (*f_gd_test)(const GradientDescent* gd_node, const Matrix* features, const Matrix* dep_var,
|
||||
const Matrix* weights, Scores* scores);
|
||||
typedef gd_float (*f_gd_predict)(const Matrix* features, const Matrix* weights);
|
||||
|
||||
typedef struct GradientDescentAlgorithm {
|
||||
const char* name;
|
||||
int flags;
|
||||
int metrics;
|
||||
// values for binary algorithms
|
||||
// e.g. (0,1) for logistic regression or (-1,1) for svm classifier
|
||||
gd_float min_class;
|
||||
gd_float max_class;
|
||||
// callbacks for hooks
|
||||
f_gd_gradients gradients_callback;
|
||||
f_gd_test test_callback;
|
||||
f_gd_predict predict_callback;
|
||||
} GradientDescentAlgorithm;
|
||||
|
||||
GradientDescentAlgorithm* gd_get_algorithm(AlgorithmML algorithm);
|
||||
|
||||
inline bool dep_var_is_continuous(const GradientDescentAlgorithm* algo) {
|
||||
return ((algo->flags & (GD_DEPENDENT_VAR_BINARY | GD_DEPENDENT_VAR_CATEGORICAL)) == 0);
|
||||
}
|
||||
|
||||
inline bool dep_var_is_binary(const GradientDescentAlgorithm* algo) {
|
||||
return ((algo->flags & GD_DEPENDENT_VAR_BINARY) != 0);
|
||||
}
|
||||
|
||||
gd_float gd_datum_get_float(Oid type, Datum datum);
|
||||
Datum gd_float_get_datum(Oid type, gd_float value);
|
||||
|
||||
inline Oid
|
||||
get_atttypid(GradientDescentState* gd_state, int col) {
|
||||
return gd_state->tupdesc->attrs[col]->atttypid;
|
||||
}
|
||||
|
||||
inline bool
|
||||
get_attbyval(GradientDescentState* gd_state, int col) {
|
||||
return gd_state->tupdesc->attrs[col]->attbyval;
|
||||
}
|
||||
|
||||
inline int
|
||||
get_attlen(GradientDescentState* gd_state, int col) {
|
||||
return gd_state->tupdesc->attrs[col]->attlen;
|
||||
}
|
||||
|
||||
inline int
|
||||
get_atttypmod(GradientDescentState* gd_state, int col) {
|
||||
return gd_state->tupdesc->attrs[col]->atttypmod;
|
||||
}
|
||||
|
||||
inline int
|
||||
get_natts(GradientDescentState* gd_state) {
|
||||
return gd_state->tupdesc->natts;
|
||||
}
|
||||
|
||||
struct Model;
|
||||
|
||||
ModelPredictor gd_predict_prepare(const Model* model);
|
||||
Datum gd_predict(ModelPredictor pred, Datum* values, bool* isnull, Oid* types, int values_size);
|
||||
|
||||
inline const GradientDescent* gd_get_node(const GradientDescentState* gd_state) {
|
||||
return (const GradientDescent*)gd_state->ss.ps.plan;
|
||||
}
|
||||
|
||||
const char* gd_get_expr_name(GradientDescentExprField field);
|
||||
Datum ExecEvalGradientDescent(GradientDescentExprState* mlstate, ExprContext* econtext,
|
||||
bool* isNull, ExprDoneCond* isDone);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
// optimizers
|
||||
|
||||
typedef struct OptimizerGD {
|
||||
// shared members, initialized by the caller and not by the specialization
|
||||
Matrix weights;
|
||||
Matrix gradients;
|
||||
|
||||
// specialized methods
|
||||
void (*start_iteration)(OptimizerGD* optimizer);
|
||||
void (*update_batch)(OptimizerGD* optimizer, const Matrix* features, const Matrix* dep_var);
|
||||
void (*end_iteration)(OptimizerGD* optimizer);
|
||||
void (*release)(OptimizerGD* optimizer);
|
||||
} OptimizerGD;
|
||||
|
||||
OptimizerGD* gd_init_optimizer_gd(const GradientDescentState* gd_state);
|
||||
OptimizerGD* gd_init_optimizer_ngd(const GradientDescentState* gd_state);
|
||||
|
||||
const char* gd_get_optimizer_name(OptimizerML optimizer);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// shuffle
|
||||
|
||||
typedef struct ShuffleGD {
|
||||
OptimizerGD* optimizer;
|
||||
void (*start_iteration)(ShuffleGD* shuffle);
|
||||
Matrix* (*get)(ShuffleGD* shuffle, Matrix** dep_var);
|
||||
void (*unget)(ShuffleGD* shuffle, int tuples);
|
||||
void (*end_iteration)(ShuffleGD* shuffle);
|
||||
void (*release)(ShuffleGD* shuffle);
|
||||
} ShuffleGD;
|
||||
|
||||
ShuffleGD* gd_init_shuffle_cache(const GradientDescentState* gd_state);
|
||||
|
||||
|
||||
#endif /* GD_H */
|
|
@ -0,0 +1,125 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
*
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/db4ai/hyperparameter_validation.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
|
||||
#ifndef DB4AI_HYPERPARAMETER_VALIDATION_H
|
||||
#define DB4AI_HYPERPARAMETER_VALIDATION_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "nodes/pg_list.h"
|
||||
#include "utils/builtins.h"
|
||||
|
||||
|
||||
#include <float.h>
|
||||
|
||||
|
||||
struct HyperparameterDefinition {
|
||||
const char* name;
|
||||
Datum default_value;
|
||||
Datum min_value;
|
||||
Datum max_value;
|
||||
const char** valid_values;
|
||||
void (* enum_setter)(const char* s, void* enum_addr);
|
||||
Oid type;
|
||||
int32_t valid_values_size;
|
||||
int32_t offset;
|
||||
bool min_inclusive;
|
||||
bool max_inclusive;
|
||||
};
|
||||
|
||||
struct HyperparameterValidation {
|
||||
void* min_value;
|
||||
bool min_inclusive;
|
||||
void* max_value;
|
||||
bool max_inclusive;
|
||||
const char** valid_values;
|
||||
int32_t valid_values_size;
|
||||
};
|
||||
|
||||
// Stores the final value of a hyperparameter into the model warehouse output.
void update_model_hyperparameter(Model* model, const char* name, Oid type, Datum value);

// Validates and applies all hyperparameters of an algorithm, filling the
// algorithm-specific hyperparameter struct.
void configure_hyperparameters(AlgorithmML algorithm,
    List* hyperparameters, Model* model, void* hyperparameter_struct);

// Set int hyperparameter
void set_hyperparameter_value(const char* name, int* hyperparameter,
    Value* value, VariableSetKind kind, int default_value, Model* model,
    HyperparameterValidation* validation);

// Set double hyperparameter
void set_hyperparameter_value(const char* name, double* hyperparameter, Value* value,
    VariableSetKind kind, double default_value, Model* model,
    HyperparameterValidation* validation);

// Set string hyperparameter (no const)
void set_hyperparameter_value(const char* name, char** hyperparameter, Value* value,
    VariableSetKind kind, char* default_value, Model* model,
    HyperparameterValidation* validation);

// Set boolean hyperparameter
void set_hyperparameter_value(const char* name, bool* hyperparameter,
    Value* value, VariableSetKind kind, bool default_value, Model* model,
    HyperparameterValidation* validation);
|
||||
|
||||
|
||||
// General purpouse method to set the hyperparameters
|
||||
// Locate the hyperparameter in the list by name and set it to the selected value
|
||||
// Return the index in the list of the hyperparameters. If not found return -1
|
||||
template<typename T>
|
||||
int set_hyperparameter(const char* name, T* hyperparameter, List* hyperparameters, T default_value, Model* model,
|
||||
HyperparameterValidation* validation) {
|
||||
int result = 0;
|
||||
foreach_cell(it, hyperparameters) {
|
||||
VariableSetStmt* current = lfirst_node(VariableSetStmt, it);
|
||||
|
||||
if (strcmp(current->name, name) == 0) {
|
||||
if (list_length(current->args) > 1) {
|
||||
ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("Hyperparameter %s cannot be a list", current->name)));
|
||||
}
|
||||
|
||||
Value* value = NULL;
|
||||
if (current->args != NULL) {
|
||||
A_Const* aconst = NULL;
|
||||
aconst = linitial_node(A_Const, current->args);
|
||||
value = &aconst->val;
|
||||
}
|
||||
|
||||
set_hyperparameter_value(name, hyperparameter, value, current->kind, default_value, model, validation);
|
||||
return result;
|
||||
}
|
||||
result++;
|
||||
}
|
||||
|
||||
// If not set by user, set the default value
|
||||
set_hyperparameter_value(name, hyperparameter, NULL, VAR_SET_DEFAULT, default_value, model, validation);
|
||||
return -1;
|
||||
}
|
||||
|
||||
#endif
|
|
@ -0,0 +1,527 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* matrix.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/dbmind/db4ai/executor/gd/matrix.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef MATRIX_H
|
||||
#define MATRIX_H
|
||||
|
||||
#include "postgres.h"
|
||||
#include "lib/stringinfo.h"
|
||||
#include "db4ai/scores.h"
|
||||
|
||||
#include "math.h"
|
||||
#include "float.h"
|
||||
|
||||
#define MATRIX_CACHE 16
|
||||
|
||||
typedef float4 gd_float;
|
||||
|
||||
// Dense matrix of gd_float coefficients. Matrices small enough to fit
// MATRIX_CACHE cells are stored inline in `cache` to avoid an allocation.
typedef struct Matrix {
    int rows;                     // current number of rows
    int columns;                  // current number of columns
    bool transposed;              // true for a virtual (non-materialized) transpose view
    int allocated;                // cells backing `data`; 0 for transpose views
    gd_float *data;               // points to `cache` or to palloc'ed storage
    gd_float cache[MATRIX_CACHE]; // inline storage for small matrices
} Matrix;
|
||||
|
||||
// returns the number of bytes needed to hold a matrix of the given shape
int matrix_expected_size(int rows, int columns = 1);

// initializes a bidimensional matrix or a vector (#columns = 1) to zeros
void matrix_init(Matrix *matrix, int rows, int columns = 1);

// initializes with a copy of a matrix
void matrix_init_clone(Matrix *matrix, const Matrix *src);

// initializes with a transposed view of a matrix; it is a virtual operation
void matrix_init_transpose(Matrix *matrix, const Matrix *src);

// releases the memory of a matrix and makes it empty
void matrix_release(Matrix *matrix);

// fills a matrix with zeroes
void matrix_zeroes(Matrix *matrix);

// changes the shape of a matrix
// only the number of rows for now; resizing columns will be done later
void matrix_resize(Matrix *matrix, int rows, int columns);

// multiplies a matrix by a vector, row by row
void matrix_mult_vector(const Matrix *matrix, const Matrix *vector, Matrix *result);

// multiplies two matrices coefficient by coefficient (a.k.a. Hadamard or Schur product)
void matrix_mult_entrywise(Matrix *m1, const Matrix *m2);

// multiplies a matrix by a scalar, coefficient by coefficient
void matrix_mult_scalar(Matrix *matrix, gd_float factor);

// divides a matrix by a scalar, coefficient by coefficient
void matrix_divide(Matrix *matrix, gd_float factor);

// adds a matrix to another matrix, coefficient by coefficient
void matrix_add(Matrix *m1, const Matrix *m2);

// subtracts a matrix from another matrix, coefficient by coefficient
void matrix_subtract(Matrix *m1, const Matrix *m2);

// squares all coefficients
void matrix_square(Matrix *matrix);

// obtains the square root of all coefficients
void matrix_square_root(Matrix *matrix);

// computes the sigmoid of all coefficients: 1.0 / (1.0 + exp(-c))
void matrix_sigmoid(Matrix *matrix);

// computes the natural logarithm of all coefficients
void matrix_log(Matrix *matrix);

// computes the natural logarithm log(1+c) of all coefficients
void matrix_log1p(Matrix *matrix);

// negates all coefficients (-c)
void matrix_negate(Matrix *matrix);

// complements all coefficients (1-c)
void matrix_complement(Matrix *matrix);

// makes sure all coefficients are c >= 0 (clamps negatives to zero)
void matrix_positive(Matrix *matrix);

// returns the sum of all coefficients
gd_float matrix_get_sum(Matrix *matrix);

// scales a matrix row by row, using two vectors (N & D)
// where each coefficient is scaled c' = (c - N) / D
// - normalization: N=min, D=(max-min)
// - standardization: N=avg, D=stdev
void matrix_scale(Matrix *matrix, const Matrix *m_n, const Matrix *m_d);

// computes the dot product of two vectors
void/* see inline below */;

// computes the dot product of two vectors
gd_float matrix_dot(const Matrix *v1, const Matrix *v2);

// converts all coefficients to binary values w.r.t. a threshold
// low: v < threshold; high: v >= threshold
void matrix_binary(Matrix *matrix, gd_float threshold, gd_float low, gd_float high);

// compares two binary vectors, accumulating tp/tn/fp/fn counts into scores
void matrix_relevance(const Matrix *v1, const Matrix *v2, Scores *scores, gd_float positive);

// prints to a buffer
void matrix_print(const Matrix *matrix, StringInfo buf, bool full);

// prints into the log
void elog_matrix(int elevel, const char *msg, const Matrix *matrix);
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////////////
|
||||
// inline
|
||||
|
||||
inline int matrix_expected_size(int rows, int columns)
|
||||
{
|
||||
Assert(rows > 0);
|
||||
Assert(columns > 0);
|
||||
int cells = rows * columns;
|
||||
if (cells <= MATRIX_CACHE)
|
||||
cells = 0; // cached, no extra memory
|
||||
|
||||
return sizeof(Matrix) + cells * sizeof(gd_float);
|
||||
}
|
||||
|
||||
// Initializes a matrix (or vector when columns == 1) of the given shape
// with all coefficients set to zero. Small shapes use the inline cache.
inline void matrix_init(Matrix *matrix, int rows, int columns)
{
    Assert(matrix != nullptr);
    Assert(rows > 0);
    Assert(columns > 0);
    matrix->transposed = false;
    matrix->rows = rows;
    matrix->columns = columns;
    matrix->allocated = rows * columns;
    if (matrix->allocated > MATRIX_CACHE) {
        // too large for the inline cache; palloc0 already zeroes the cells
        matrix->data = (gd_float *)palloc0(matrix->allocated * sizeof(gd_float));
    } else {
        matrix->data = matrix->cache;
        errno_t rc = memset_s(matrix->data, MATRIX_CACHE * sizeof(gd_float), 0, matrix->allocated * sizeof(gd_float));
        securec_check(rc, "", "");
    }
}
|
||||
|
||||
// Initializes matrix as a deep copy of src (src must not be a transpose view).
inline void matrix_init_clone(Matrix *matrix, const Matrix *src)
{
    Assert(matrix != nullptr);
    Assert(src != nullptr);
    Assert(!src->transposed);
    matrix_init(matrix, src->rows, src->columns);
    size_t payload = sizeof(gd_float) * src->rows * src->columns;
    errno_t rc = memcpy_s(matrix->data, payload, src->data, payload);
    securec_check(rc, "", "");
}
|
||||
|
||||
// fake transpose, only points to the data of the other matrix
|
||||
inline void matrix_init_transpose(Matrix *matrix, const Matrix *src)
|
||||
{
|
||||
Assert(matrix != nullptr);
|
||||
Assert(src != nullptr);
|
||||
Assert(!src->transposed);
|
||||
matrix->transposed = true;
|
||||
matrix->rows = src->columns;
|
||||
matrix->columns = src->rows;
|
||||
matrix->allocated = 0;
|
||||
matrix->data = src->data;
|
||||
}
|
||||
|
||||
// Releases any owned storage and leaves the matrix empty (0 x 0).
inline void matrix_release(Matrix *matrix)
{
    Assert(matrix != nullptr);
    if (matrix->allocated > 0) {
        Assert(matrix->data != nullptr);
        // only palloc'ed storage is freed; the inline cache is part of the struct
        if (matrix->data != matrix->cache)
            pfree(matrix->data);

        matrix->allocated = 0;
    }
    matrix->rows = 0;
    matrix->columns = 0;
    matrix->data = nullptr;
}
|
||||
|
||||
// Fills all used coefficients of the matrix with zeroes.
inline void matrix_zeroes(Matrix *matrix)
{
    Assert(matrix != nullptr);
    Assert(!matrix->transposed);
    size_t used_bytes = sizeof(gd_float) * matrix->rows * matrix->columns;
    errno_t rc = memset_s(matrix->data, sizeof(gd_float) * matrix->allocated, 0, used_bytes);
    securec_check(rc, "", "");
}
|
||||
|
||||
// Changes the number of rows of a matrix in place. Changing the number of
// columns and growing beyond the allocated storage are not supported yet.
inline void matrix_resize(Matrix *matrix, int rows, int columns)
{
    Assert(matrix != nullptr);
    Assert(!matrix->transposed);
    Assert(rows > 0);
    Assert(columns > 0);
    if (columns != matrix->columns)
        ereport(ERROR,
            (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED), errmsg("resize column not yet supported")));

    if (rows == matrix->rows)
        return;  // nothing to do

    if (rows > matrix->rows) {
        // shrinking always fits; growing must fit the existing allocation
        int required = rows * matrix->columns;
        if (required > matrix->allocated)
            ereport(ERROR, (errmodule(MOD_DB4AI), errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
                errmsg("matrix growth not yet supported")));
    }
    matrix->rows = rows;
}
|
||||
|
||||
// Multiplies matrix * vector into the column vector result:
// result[r] = sum_c matrix[r][c] * vector[c]
inline void matrix_mult_vector(const Matrix *matrix, const Matrix *vector, Matrix *result)
{
    Assert(matrix != nullptr);
    Assert(vector != nullptr);
    Assert(vector->columns == 1);
    Assert(!vector->transposed);
    Assert(result != nullptr);
    Assert(!result->transposed);
    Assert(result->columns == 1);
    Assert(matrix->rows == result->rows);
    Assert(matrix->columns == vector->rows);

    if (matrix->transposed) {
        gd_float *pd = result->data;
        // loop assumes that the data has not been physically transposed:
        // row r of the view is column r of the underlying data, so
        // consecutive elements of the virtual row are matrix->rows apart
        for (int r = 0; r < matrix->rows; r++) {
            const gd_float *pm = matrix->data + r;
            const gd_float *pv = vector->data;
            gd_float x = 0.0;
            for (int c = 0; c < matrix->columns; c++) {
                x += *pm * *pv++;
                pm += matrix->rows;  // stride to the next element of the virtual row
            }
            *pd++ = x;
        }
    } else {
        // non-transposed case: each row is contiguous, walk the data linearly
        const gd_float *pm = matrix->data;
        gd_float *pd = result->data;
        for (int r = 0; r < matrix->rows; r++) {
            const gd_float *pv = vector->data;
            size_t count = matrix->columns;
            gd_float x = 0.0;
            while (count-- > 0)
                x += *pv++ * *pm++;
            *pd++ = x;
        }
    }
}
|
||||
|
||||
// Entrywise (Hadamard/Schur) product: m1[i] *= m2[i] for every coefficient.
inline void matrix_mult_entrywise(Matrix *m1, const Matrix *m2)
{
    Assert(m1 != nullptr);
    Assert(!m1->transposed);
    Assert(m2 != nullptr);
    Assert(!m2->transposed);
    Assert(m1->rows == m2->rows);
    Assert(m1->columns == m2->columns);

    const size_t total = m1->rows * m1->columns;
    gd_float *dst = m1->data;
    const gd_float *src = m2->data;
    for (size_t i = 0; i < total; i++)
        dst[i] *= src[i];
}
|
||||
|
||||
// Multiplies every coefficient of the matrix by factor.
inline void matrix_mult_scalar(Matrix *matrix, gd_float factor)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] *= factor;
}
|
||||
|
||||
// Divides every coefficient of the matrix by factor (factor must be non-zero).
inline void matrix_divide(Matrix *matrix, gd_float factor)
{
    Assert(matrix != nullptr);
    Assert(factor != 0.0);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] /= factor;
}
|
||||
|
||||
// Adds m2 to m1 coefficient by coefficient: m1[i] += m2[i].
inline void matrix_add(Matrix *m1, const Matrix *m2)
{
    Assert(m1 != nullptr);
    Assert(!m1->transposed);
    Assert(m2 != nullptr);
    Assert(!m2->transposed);
    Assert(m1->rows == m2->rows);
    Assert(m1->columns == m2->columns);
    const size_t total = m1->rows * m1->columns;
    gd_float *dst = m1->data;
    const gd_float *src = m2->data;
    for (size_t i = 0; i < total; i++)
        dst[i] += src[i];
}
|
||||
|
||||
// Subtracts m2 from m1 coefficient by coefficient: m1[i] -= m2[i].
inline void matrix_subtract(Matrix *m1, const Matrix *m2)
{
    Assert(m1 != nullptr);
    Assert(!m1->transposed);
    Assert(m2 != nullptr);
    Assert(!m2->transposed);
    Assert(m1->rows == m2->rows);
    Assert(m1->columns == m2->columns);
    const size_t total = m1->rows * m1->columns;
    gd_float *dst = m1->data;
    const gd_float *src = m2->data;
    for (size_t i = 0; i < total; i++)
        dst[i] -= src[i];
}
|
||||
|
||||
// Dot product of two column vectors: sum_i v1[i] * v2[i].
inline gd_float matrix_dot(const Matrix *v1, const Matrix *v2)
{
    Assert(v1 != nullptr);
    Assert(!v1->transposed);
    Assert(v2 != nullptr);
    Assert(!v2->transposed);
    Assert(v1->rows == v2->rows);
    Assert(v1->columns == 1);
    Assert(v2->columns == 1);

    const gd_float *pa = v1->data;
    const gd_float *pb = v2->data;
    gd_float acc = 0;
    for (int i = 0; i < v1->rows; i++)
        acc += pa[i] * pb[i];

    return acc;
}
|
||||
|
||||
// Squares every coefficient in place: c = c * c.
inline void matrix_square(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] *= cell[i];
}
|
||||
|
||||
// Replaces every coefficient with its square root.
inline void matrix_square_root(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = sqrt(cell[i]);
}
|
||||
|
||||
// Applies the logistic sigmoid to every coefficient: c = 1 / (1 + exp(-c)).
inline void matrix_sigmoid(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = 1.0 / (1.0 + exp(-cell[i]));
}
|
||||
|
||||
// Replaces every coefficient with its natural logarithm.
inline void matrix_log(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = log(cell[i]);
}
|
||||
|
||||
// Computes the natural logarithm log(1+c) of all coefficients.
// Uses log1p() instead of log(1 + c): with single-precision coefficients,
// the explicit sum 1 + c loses all significant bits of c once |c| is below
// FLT_EPSILON, while log1p stays accurate for small |c|.
inline void matrix_log1p(Matrix *matrix)
{
    Assert(matrix != nullptr);
    size_t count = matrix->rows * matrix->columns;
    gd_float *pd = matrix->data;
    while (count-- > 0) {
        gd_float v = *pd;
        *pd++ = log1p(v);
    }
}
|
||||
|
||||
// Negates every coefficient: c = -c.
inline void matrix_negate(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = -cell[i];
}
|
||||
|
||||
// Complements every coefficient: c = 1 - c.
inline void matrix_complement(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = 1.0 - cell[i];
}
|
||||
|
||||
// Clamps every negative coefficient to zero, so that all c >= 0.
inline void matrix_positive(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++) {
        if (cell[i] < 0.0)
            cell[i] = 0.0;
    }
}
|
||||
|
||||
// Returns the sum of all coefficients.
inline gd_float matrix_get_sum(Matrix *matrix)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    const gd_float *cell = matrix->data;
    gd_float acc = 0.0;
    for (size_t i = 0; i < total; i++)
        acc += cell[i];
    return acc;
}
|
||||
|
||||
// Scales the matrix row by row with two column vectors N and D:
// every coefficient becomes c' = (c - N[col]) / D[col].
// - normalization: N=min, D=(max-min)
// - standardization: N=avg, D=stdev
inline void matrix_scale(Matrix *matrix, const Matrix *m_n, const Matrix *m_d)
{
    Assert(matrix != nullptr);
    Assert(m_n != nullptr);
    Assert(m_d != nullptr);
    Assert(!matrix->transposed);
    Assert(!m_n->transposed);
    Assert(matrix->columns == m_n->rows);
    Assert(m_n->columns == 1);
    Assert(!m_d->transposed);
    Assert(matrix->columns == m_d->rows);
    Assert(m_d->columns == 1);
    gd_float *cell = matrix->data;
    for (int r = 0; r < matrix->rows; r++) {
        // both scaling vectors are walked from the start for every row
        const gd_float *numer = m_n->data;
        const gd_float *denom = m_d->data;
        for (int c = 0; c < matrix->columns; c++) {
            *cell = (*cell - *numer++) / *denom++;
            cell++;
        }
    }
}
|
||||
|
||||
// Converts every coefficient to one of two values w.r.t. a threshold:
// low when v < threshold, high when v >= threshold.
inline void matrix_binary(Matrix *matrix, gd_float threshold, gd_float low, gd_float high)
{
    Assert(matrix != nullptr);
    const size_t total = matrix->rows * matrix->columns;
    gd_float *cell = matrix->data;
    for (size_t i = 0; i < total; i++)
        cell[i] = (cell[i] < threshold ? low : high);
}
|
||||
|
||||
inline void matrix_relevance(const Matrix *v1, const Matrix *v2, Scores *scores, gd_float positive)
|
||||
{
|
||||
Assert(v1 != nullptr);
|
||||
Assert(!v1->transposed);
|
||||
Assert(v2 != nullptr);
|
||||
Assert(!v2->transposed);
|
||||
Assert(v1->rows == v2->rows);
|
||||
Assert(v1->columns == 1);
|
||||
Assert(v2->columns == 1);
|
||||
|
||||
size_t count = v1->rows;
|
||||
const gd_float *p1 = v1->data;
|
||||
const gd_float *p2 = v2->data;
|
||||
while (count-- > 0) {
|
||||
gd_float x = *p1++;
|
||||
gd_float y = *p2++;
|
||||
if (x == positive) {
|
||||
// positive
|
||||
if (y == positive)
|
||||
scores->tp++;
|
||||
else
|
||||
scores->fp++;
|
||||
} else {
|
||||
// negative
|
||||
if (y != positive)
|
||||
scores->tn++;
|
||||
else
|
||||
scores->fn++;
|
||||
}
|
||||
scores->count++;
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* MATRIX_H */
|
|
@ -0,0 +1,112 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
 * model_warehouse.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/gausskernel/catalog/model_warehouse.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef DB4AI_MODEL_WAREHOUSE
|
||||
#define DB4AI_MODEL_WAREHOUSE
|
||||
|
||||
#include "postgres.h"
|
||||
#include "nodes/pg_list.h"
|
||||
#include "nodes/plannodes.h"
|
||||
#include "nodes/execnodes.h"
|
||||
|
||||
// One hyperparameter value as stored with a trained model.
struct Hyperparameter{
    const char* name;  // hyperparameter name
    Oid type;          // PostgreSQL type oid of value
    Datum value;       // the configured value
};
|
||||
|
||||
// One generic name/value pair of training information stored with a model.
struct TrainingInfo{
    const char* name;  // attribute name
    Oid type;          // PostgreSQL type oid of value
    Datum value;       // attribute value
};
|
||||
|
||||
// One named evaluation score (e.g. a metric) stored with a model.
struct TrainingScore{
    const char* name;  // score name
    double value;      // score value
};
|
||||
|
||||
// Base class for models
struct Model{
    AlgorithmML algorithm;         // algorithm that produced this model
    const char* model_name;
    const char* sql;               // SQL text associated with the model
    double exec_time_secs;         // execution time, in seconds
    double pre_time_secs;          // preprocessing time, in seconds
    int64_t processed_tuples;
    int64_t discarded_tuples;
    List* train_info;              // List of TrainingInfo
    List* hyperparameters;         // List of Hyperparameter
    List* scores;                  // List of TrainingScore
    Oid return_type;               // Return type of the model
    int32_t num_actual_iterations;
};
|
||||
|
||||
// Used by all GradientDescent variants
struct ModelGradientDescent{
    Model model;       // common model header (must be first)
    Datum weights;     // Float[]
    int ncategories;   // 0 for continuous
    Datum categories;  // only for categorical, an array of return_type[ncategories]
};
|
||||
|
||||
// Used by K-Means models: per-centroid information in the warehouse format.
typedef struct WHCentroid {
    double objective_function = DBL_MAX;
    double avg_distance_to_centroid = DBL_MAX;
    double min_distance_to_centroid = DBL_MAX;
    double max_distance_to_centroid = DBL_MAX;
    double std_dev_distance_to_centroid = DBL_MAX;
    uint64_t cluster_size = 0ULL;   // number of points assigned to this centroid
    double* coordinates = nullptr;  // centroid coordinates (length presumably ModelKMeans.dimension — confirm)
    uint32_t id = 0U;
} WHCentroid;
|
||||
|
||||
// K-Means model as stored in the model warehouse.
struct ModelKMeans {
    Model model;  // common model header (must be first)
    /*
     * the following fields are put here for convenience
     */
    uint32_t original_num_centroids = 0U;  // number of centroids requested
    uint32_t actual_num_centroids = 0U;    // number of centroids produced
    uint32_t dimension = 0U;               // dimensionality of the data points
    uint32_t distance_function_id = 0U;
    uint64_t seed = 0ULL;                  // RNG seed used for seeding
    WHCentroid* centroids = nullptr;       // array of actual_num_centroids entries
};
|
||||
|
||||
// Used by XGBoost
struct ModelBinary {
    Model model;         // common model header (must be first)
    uint64_t model_len;  // length in bytes of model_data
    Datum model_data;    // varlena (void*)
};
|
||||
|
||||
|
||||
// Store the model in the catalog tables
void store_model(const Model* model);

// Get the model from the catalog tables; when only_model is true the
// caller presumably wants just the base Model fields — confirm in implementation
Model* get_model(const char* model_name, bool only_model);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/db4ai/predict_by.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef DB4AI_PREDICT_BY_H
|
||||
#define DB4AI_PREDICT_BY_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "db4ai/model_warehouse.h"
|
||||
#include "nodes/parsenodes_common.h"
|
||||
#include "fmgr/fmgr_comp.h"
|
||||
#include "parser/parse_node.h"
|
||||
|
||||
|
||||
|
||||
typedef void* ModelPredictor; // Deserialized version of the model that can compute efficiently predictions

// Algorithm-specific prediction entry points.
struct PredictorInterface {
    // turns a stored model into a ready-to-use predictor
    ModelPredictor (*prepare) (const Model* model);
    // computes a prediction from one tuple of input values
    Datum (*predict) (ModelPredictor pred, Datum* values, bool* isnull, Oid* types, int values_size);
};

// PREDICT BY SQL entry points, one per result type
Datum db4ai_predict_by_bool(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_int32(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_int64(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_float4(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_float8(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_numeric(PG_FUNCTION_ARGS);
Datum db4ai_predict_by_text(PG_FUNCTION_ARGS);
|
||||
|
||||
|
||||
#endif
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* scores.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/dbmind/db4ai/executor/gd/scores.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef SCORES_H
|
||||
#define SCORES_H
|
||||
|
||||
#include "postgres.h"
|
||||
|
||||
// Evaluation counters accumulated while scoring a model.
typedef struct Scores {
    int count;  // total number of observations
    // for classification or binary regression
    int tp;     // true positive
    int tn;     // true negative
    int fp;     // false positive
    int fn;     // false negative
    // for continuous regression
    float mse;  // mean squared error
} Scores;
|
||||
|
||||
// Resets all counters of a Scores structure to zero.
inline void scores_init(Scores *scores)
{
    errno_t rc = memset_s(scores, sizeof(Scores), 0, sizeof(Scores));
    securec_check(rc, "", "");
}
|
||||
|
||||
// (tp + tn) / n
|
||||
inline double get_accuracy(const Scores *scores)
|
||||
{
|
||||
return (scores->tp + scores->tn) / (double)scores->count;
|
||||
}
|
||||
|
||||
// tp / (tp + fp)
|
||||
inline double get_precision(const Scores *scores, bool *has)
|
||||
{
|
||||
double d = scores->tp + scores->fp;
|
||||
if (d > 0) {
|
||||
*has = true;
|
||||
return scores->tp / d;
|
||||
}
|
||||
*has = false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// tp / (tp + fn)
|
||||
inline double get_recall(const Scores *scores, bool *has)
|
||||
{
|
||||
double d = scores->tp + scores->fn;
|
||||
if (d > 0) {
|
||||
*has = true;
|
||||
return scores->tp / d;
|
||||
}
|
||||
*has = false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif /* SCORES_H */
|
|
@ -0,0 +1,38 @@
|
|||
/*
|
||||
* Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
||||
*
|
||||
* openGauss is licensed under Mulan PSL v2.
|
||||
* You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
* You may obtain a copy of Mulan PSL v2 at:
|
||||
*
|
||||
* http://license.coscl.org.cn/MulanPSL2
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
* See the Mulan PSL v2 for more details.
|
||||
*---------------------------------------------------------------------------------------
|
||||
*
|
||||
* nodeGD.h
|
||||
*
|
||||
* IDENTIFICATION
|
||||
* src/include/executor/nodeGD.h
|
||||
*
|
||||
* ---------------------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef NODE_GD_H
|
||||
#define NODE_GD_H
|
||||
|
||||
#include "nodes/execnodes.h"
|
||||
|
||||
extern GradientDescentState* ExecInitGradientDescent(GradientDescent* node, EState* estate, int eflags);
|
||||
extern TupleTableSlot* ExecGradientDescent(GradientDescentState* state);
|
||||
extern void ExecEndGradientDescent(GradientDescentState* state);
|
||||
|
||||
typedef void (*GradientDescentHook_iteration)(GradientDescentState* state);
|
||||
extern GradientDescentHook_iteration gdhook_iteration;
|
||||
|
||||
List* makeGradientDescentExpr(AlgorithmML algorithm, List* list, int field);
|
||||
|
||||
#endif /* NODE_GD_H */
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
Copyright (c) 2021 Huawei Technologies Co.,Ltd.
|
||||
|
||||
openGauss is licensed under Mulan PSL v2.
|
||||
You can use this software according to the terms and conditions of the Mulan PSL v2.
|
||||
You may obtain a copy of Mulan PSL v2 at:
|
||||
|
||||
http://license.coscl.org.cn/MulanPSL2
|
||||
|
||||
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
|
||||
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
||||
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
||||
See the Mulan PSL v2 for more details.
|
||||
---------------------------------------------------------------------------------------
|
||||
|
||||
nodeKMeans.h
|
||||
Functions related to the k-means operator
|
||||
|
||||
IDENTIFICATION
|
||||
src/include/executor/nodeKMeans.h
|
||||
|
||||
---------------------------------------------------------------------------------------
|
||||
**/
|
||||
|
||||
#ifndef DB4AI_NODEKMEANS_H
|
||||
#define DB4AI_NODEKMEANS_H
|
||||
|
||||
#include "nodes/execnodes.h"
|
||||
|
||||
/*
|
||||
* do not touch this unless you know what you're doing (the implications on
|
||||
* nodeKMeans.cpp and create_model.cpp). you are warned.
|
||||
*/
|
||||
uint32_t constexpr NUM_ATTR_OUTPUT = 15U;
|
||||
|
||||
extern KMeansState* ExecInitKMeans(KMeans* node, EState* estate, int eflags);
|
||||
|
||||
extern TupleTableSlot* ExecKMeans(KMeansState* node);
|
||||
|
||||
extern void ExecEndKMeans(KMeansState* node);
|
||||
|
||||
#endif //DB4AI_NODEKMEANS_H
|
|
@ -21,6 +21,7 @@
|
|||
#include "nodes/params.h"
|
||||
#include "nodes/plannodes.h"
|
||||
#include "storage/pagecompress.h"
|
||||
#include "utils/array.h"
|
||||
#include "utils/bloom_filter.h"
|
||||
#include "utils/reltrigger.h"
|
||||
#include "utils/sortsupport.h"
|
||||
|
@ -28,6 +29,8 @@
|
|||
#include "utils/tuplestore.h"
|
||||
#include "vecexecutor/vectorbatch.h"
|
||||
|
||||
#include "db4ai/matrix.h"
|
||||
|
||||
#ifdef ENABLE_MOT
|
||||
// forward declaration for MOT JitContext
|
||||
namespace JitExec
|
||||
|
@ -2483,5 +2486,133 @@ inline bool BitmapNodeNeedSwitchPartRel(BitmapHeapScanState* node)
|
|||
|
||||
extern RangeScanInRedis reset_scan_qual(Relation currHeapRel, ScanState *node);
|
||||
|
||||
/*
 * DB4AI GD node: used to train models using Gradient Descent
 */

struct GradientDescentAlgorithm;
struct GradientDescentExpr;
struct OptimizerGD;
struct ShuffleGD;

typedef struct GradientDescentState {
    ScanState ss; /* its first field is NodeTag */
    GradientDescentAlgorithm* algorithm;  // algorithm descriptor
    OptimizerGD* optimizer;
    ShuffleGD* shuffle;

    // tuple description
    TupleDesc tupdesc;
    int n_features; // number of features

    // dependent variable binary values
    int num_classes;
    Datum binary_classes[2];

    // training state
    bool done;            // set when training has finished
    Matrix weights;       // current model weights
    double learning_rate;
    int n_iterations;
    int usecs;            // execution time
    int processed;        // tuples
    int discarded;
    float loss;
    Scores scores;
} GradientDescentState;
|
||||
|
||||
|
||||
// Expression state wrapping a GradientDescentExpr during execution.
typedef struct GradientDescentExprState {
    ExprState xprstate;       // common expression state (must be first)
    PlanState* ps;            // owning plan state
    GradientDescentExpr* xpr; // the expression being evaluated
} GradientDescentExprState;
|
||||
|
||||
/*
 * DB4AI k-means node
 */

/*
 * to keep running statistics on each cluster being constructed
 */
class IncrementalStatistics {
    uint64_t population = 0;    // number of observations accumulated
    double max_value = 0.;
    double min_value = DBL_MAX;
    double total = 0.;          // running sum of observations
    double s = 0;               // presumably the variance accumulator (sum of squared deviations) — confirm in implementation

public:

    // statistics of two disjoint populations can be combined/separated
    IncrementalStatistics operator+(IncrementalStatistics const& rhs) const;
    IncrementalStatistics operator-(IncrementalStatistics const& rhs) const;
    IncrementalStatistics& operator+=(IncrementalStatistics const& rhs);
    IncrementalStatistics& operator-=(IncrementalStatistics const& rhs);

    double getMin() const;
    void setMin(double);
    double getMax() const;
    void setMax(double);
    double getTotal() const;
    void setTotal(double);
    uint64_t getPopulation() const;
    void setPopulation(uint64_t);
    double getEmpiricalMean() const;
    double getEmpiricalVariance() const;
    double getEmpiricalStdDev() const;
    bool reset();
};
|
||||
|
||||
/*
|
||||
* internal representation of a centroid
|
||||
*/
|
||||
typedef struct Centroid {
|
||||
IncrementalStatistics statistics;
|
||||
ArrayType* coordinates = nullptr;
|
||||
uint32_t id = 0U;
|
||||
} Centroid;
|
||||
|
||||
/*
|
||||
* current state of the algorithm
|
||||
*/
|
||||
typedef struct KMeansStateDescription {
|
||||
Centroid* centroids[2] = {nullptr};
|
||||
ArrayType* bbox_min = nullptr;
|
||||
ArrayType* bbox_max = nullptr;
|
||||
|
||||
double (* distance)(double const*, double const*, uint32_t const dimension) = nullptr;
|
||||
|
||||
double execution_time = 0.;
|
||||
double seeding_time = 0.;
|
||||
IncrementalStatistics solution_statistics[2];
|
||||
uint64_t num_good_points = 0UL;
|
||||
uint64_t num_dead_points = 0UL;
|
||||
uint32_t current_iteration = 0U;
|
||||
uint32_t current_centroid = 0U;
|
||||
uint32_t dimension = 0U;
|
||||
uint32_t num_centroids = 0U;
|
||||
uint32_t actual_num_iterations = 0U;
|
||||
} KMeansStateDescription;
|
||||
|
||||
/*
|
||||
* current state of the k-means node
|
||||
*/
|
||||
typedef struct KMeansState {
|
||||
ScanState sst;
|
||||
KMeansStateDescription description;
|
||||
bool done = false;
|
||||
} KMeansState;
|
||||
|
||||
/*
|
||||
* internal representation of a point (not a centroid)
|
||||
*/
|
||||
typedef struct GSPoint {
|
||||
ArrayType* pg_coordinates = nullptr;
|
||||
uint32_t weight = 1U;
|
||||
uint32_t id = 0ULL;
|
||||
double distance_to_closest_centroid = DBL_MAX;
|
||||
bool should_free = false;
|
||||
} GSPoint;
|
||||
|
||||
#endif /* EXECNODES_H */
|
||||
|
||||
|
|
|
@ -725,7 +725,17 @@ typedef enum NodeTag {
|
|||
T_ClientLogicColumnParam,
|
||||
T_CreateClientLogicColumn,
|
||||
T_ClientLogicColumnRef,
|
||||
T_ExprWithComma
|
||||
T_ExprWithComma,
|
||||
// DB4AI
|
||||
T_CreateModelStmt = 5000,
|
||||
T_PredictByFunction,
|
||||
T_GradientDescent,
|
||||
T_GradientDescentState,
|
||||
T_GradientDescentExpr,
|
||||
T_GradientDescentExprState,
|
||||
T_KMeans,
|
||||
T_KMeansState,
|
||||
// End DB4AI
|
||||
} NodeTag;
|
||||
|
||||
/* if you add to NodeTag also need to add nodeTagToString */
|
||||
|
|
|
@ -55,6 +55,7 @@ typedef enum ObjectType {
|
|||
OBJECT_CONVERSION,
|
||||
OBJECT_DATABASE,
|
||||
OBJECT_DATA_SOURCE,
|
||||
OBJECT_DB4AI_MODEL, // DB4AI
|
||||
OBJECT_DOMAIN,
|
||||
OBJECT_EXTENSION,
|
||||
OBJECT_FDW,
|
||||
|
@ -1776,4 +1777,35 @@ typedef struct RenameStmt {
|
|||
bool missing_ok; /* skip error if missing? */
|
||||
} RenameStmt;
|
||||
|
||||
|
||||
|
||||
/* ----------------------
|
||||
* Create Model Statement
|
||||
* ----------------------
|
||||
*/
|
||||
typedef struct CreateModelStmt{ // DB4AI
|
||||
NodeTag type;
|
||||
char* model;
|
||||
char* architecture;
|
||||
List* hyperparameters; // List<VariableSetStmt>
|
||||
Node* select_query; // Query to be executed: SelectStmt -> Query
|
||||
List* model_features; // FEATURES clause
|
||||
List* model_target; // TARGET clause
|
||||
// Filled during transform
|
||||
AlgorithmML algorithm; // Algorithm to be executed
|
||||
} CreateModelStmt;
|
||||
|
||||
/* ----------------------
|
||||
* Prediction BY function
|
||||
* ----------------------
|
||||
*/
|
||||
typedef struct PredictByFunction{ // DB4AI
|
||||
NodeTag type;
|
||||
char* model_name;
|
||||
int model_name_location; // Only for parser
|
||||
List* model_args;
|
||||
int model_args_location; // Only for parser
|
||||
} PredictByFunction;
|
||||
|
||||
|
||||
#endif /* PARSENODES_COMMONH */
|
||||
|
|
|
@ -1399,5 +1399,111 @@ static inline bool IsJoinPlan(Node* node)
|
|||
IsA(node, VecMergeJoin) || IsA(node, HashJoin) || IsA(node, VecHashJoin);
|
||||
}
|
||||
|
||||
/*
|
||||
* DB4AI
|
||||
*/
|
||||
|
||||
|
||||
// GD optimizers
|
||||
typedef enum {
|
||||
OPTIMIZER_GD, // simple mini-batch
|
||||
OPTIMIZER_NGD, // normalized gradient descent
|
||||
} OptimizerML;
|
||||
|
||||
inline void optimizer_ml_setter(const char* str, void* optimizer_ml){
|
||||
OptimizerML* optimizer = (OptimizerML*) optimizer_ml;
|
||||
if (strcmp(str, "gd") == 0)
|
||||
*optimizer = OPTIMIZER_GD;
|
||||
else if (strcmp(str, "ngd") == 0)
|
||||
*optimizer = OPTIMIZER_NGD;
|
||||
else
|
||||
elog(ERROR, "Invalid optimizer. Current candidates are: gd (default), ngd");
|
||||
return;
|
||||
}
|
||||
|
||||
// Gradient Descent node
|
||||
typedef struct GradientDescent {
|
||||
Plan plan;
|
||||
AlgorithmML algorithm;
|
||||
int targetcol;
|
||||
|
||||
// generic hyperparameters
|
||||
OptimizerML optimizer; // default GD/mini-batch
|
||||
int max_seconds; // 0 to disable
|
||||
bool verbose;
|
||||
int max_iterations; // maximum number of iterations
|
||||
int batch_size;
|
||||
double learning_rate;
|
||||
double decay; // (0:1], learning rate decay
|
||||
double tolerance; // [0:1], 0 means to run all iterations
|
||||
int seed; // [0:N], random seed
|
||||
|
||||
// for SVM
|
||||
double lambda; // regularization strength
|
||||
} GradientDescent;
|
||||
|
||||
/*
|
||||
* DB4AI k-means
|
||||
*/
|
||||
|
||||
/*
|
||||
* current available distance functions
|
||||
*/
|
||||
typedef enum : uint32_t {
|
||||
KMEANS_L1 = 0U,
|
||||
KMEANS_L2,
|
||||
KMEANS_L2_SQUARED,
|
||||
KMEANS_LINF
|
||||
} DistanceFunction;
|
||||
|
||||
/*
|
||||
* current available seeding method
|
||||
*/
|
||||
typedef enum : uint32_t {
|
||||
KMEANS_RANDOM_SEED = 0U,
|
||||
KMEANS_BB
|
||||
} SeedingFunction;
|
||||
|
||||
/*
|
||||
* Verbosity level
|
||||
*/
|
||||
typedef enum : uint32_t {
|
||||
NO_OUTPUT = 0U,
|
||||
FASTCHECK_OUTPUT,
|
||||
VERBOSE_OUTPUT
|
||||
} Verbosity;
|
||||
|
||||
/*
|
||||
* description of the k-means instance
|
||||
*/
|
||||
struct KMeansDescription {
|
||||
char const* model_name = nullptr;
|
||||
SeedingFunction seeding = KMEANS_RANDOM_SEED;
|
||||
DistanceFunction distance = KMEANS_L2_SQUARED;
|
||||
Verbosity verbosity = NO_OUTPUT;
|
||||
uint32_t n_features = 0U;
|
||||
uint32_t batch_size = 0U;
|
||||
};
|
||||
|
||||
/*
|
||||
* current hyper-parameters
|
||||
*/
|
||||
struct KMeansHyperParameters {
|
||||
uint32_t num_centroids = 0U;
|
||||
uint32_t num_iterations = 0U;
|
||||
uint64_t external_seed = 0ULL;
|
||||
double tolerance = 0.00001;
|
||||
};
|
||||
|
||||
/*
|
||||
* the actual k-means operator
|
||||
*/
|
||||
typedef struct KMeans {
|
||||
Plan plan;
|
||||
AlgorithmML algorithm;
|
||||
KMeansDescription description;
|
||||
KMeansHyperParameters parameters;
|
||||
} KMeans;
|
||||
|
||||
#endif /* PLANNODES_H */
|
||||
|
||||
|
|
|
@ -1401,4 +1401,42 @@ typedef struct UpsertExpr {
|
|||
bool partKeyUpsert; /* we allow upsert index key and partition key in B_FORMAT */
|
||||
} UpsertExpr;
|
||||
|
||||
/*
|
||||
* DB4AI
|
||||
*/
|
||||
#define DB4AI_SNAPSHOT_VERSION_DELIMITER 1
|
||||
#define DB4AI_SNAPSHOT_VERSION_SEPARATOR 2
|
||||
|
||||
typedef enum {
|
||||
LOGISTIC_REGRESSION,
|
||||
SVM_CLASSIFICATION,
|
||||
KMEANS,
|
||||
LINEAR_REGRESSION,
|
||||
INVALID_ALGORITHM_ML,
|
||||
} AlgorithmML;
|
||||
|
||||
typedef enum GradientDescentExprField {
|
||||
// generic
|
||||
GD_EXPR_ALGORITHM,
|
||||
GD_EXPR_OPTIMIZER,
|
||||
GD_EXPR_RESULT_TYPE,
|
||||
GD_EXPR_NUM_ITERATIONS,
|
||||
GD_EXPR_EXEC_TIME_MSECS,
|
||||
GD_EXPR_PROCESSED_TUPLES,
|
||||
GD_EXPR_DISCARDED_TUPLES,
|
||||
GD_EXPR_WEIGHTS,
|
||||
GD_EXPR_CATEGORIES,
|
||||
// scores
|
||||
GD_EXPR_SCORE = 0x10000, // or-ed with the score id
|
||||
} GradientDescentExprField;
|
||||
|
||||
#define makeGradientDescentExprFieldScore(_SCORE) (GradientDescentExprField)((int)GD_EXPR_SCORE | _SCORE)
|
||||
|
||||
typedef struct GradientDescentExpr {
|
||||
Expr xpr;
|
||||
GradientDescentExprField field;
|
||||
Oid fieldtype; /* pg_type OID of the datatype */
|
||||
} GradientDescentExpr;
|
||||
|
||||
|
||||
#endif /* PRIMNODES_H */
|
||||
|
|
|
@ -226,6 +226,7 @@ PG_KEYWORD("extract", EXTRACT, COL_NAME_KEYWORD)
|
|||
PG_KEYWORD("false", FALSE_P, RESERVED_KEYWORD)
|
||||
PG_KEYWORD("family", FAMILY, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("fast", FAST, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("features", FEATURES, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("fenced", FENCED, RESERVED_KEYWORD)
|
||||
PG_KEYWORD("fetch", FETCH, RESERVED_KEYWORD)
|
||||
PG_KEYWORD("fileheader", FILEHEADER_P, UNRESERVED_KEYWORD)
|
||||
|
@ -348,6 +349,7 @@ PG_KEYWORD("minus", MINUS_P, RESERVED_KEYWORD)
|
|||
PG_KEYWORD("minute", MINUTE_P, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("minvalue", MINVALUE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("mode", MODE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("model", MODEL, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("modify", MODIFY_P, RESERVED_KEYWORD)
|
||||
PG_KEYWORD("month", MONTH_P, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("move", MOVE, UNRESERVED_KEYWORD)
|
||||
|
@ -422,6 +424,7 @@ PG_KEYWORD("pool", POOL, UNRESERVED_KEYWORD)
|
|||
PG_KEYWORD("position", POSITION, COL_NAME_KEYWORD)
|
||||
PG_KEYWORD("preceding", PRECEDING, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("precision", PRECISION, COL_NAME_KEYWORD)
|
||||
PG_KEYWORD("predict", PREDICT, UNRESERVED_KEYWORD)
|
||||
/* PGXC_BEGIN */
|
||||
PG_KEYWORD("preferred", PREFERRED, UNRESERVED_KEYWORD)
|
||||
/* PGXC_END */
|
||||
|
@ -441,6 +444,7 @@ PG_KEYWORD("query", QUERY, UNRESERVED_KEYWORD)
|
|||
PG_KEYWORD("quote", QUOTE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("randomized", RANDOMIZED, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("range", RANGE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("ratio", RATIO, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("raw", RAW, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("read", READ, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("real", REAL, COL_NAME_KEYWORD)
|
||||
|
@ -484,6 +488,7 @@ PG_KEYWORD("rownum", ROWNUM, RESERVED_KEYWORD)
|
|||
#endif
|
||||
PG_KEYWORD("rows", ROWS, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("rule", RULE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("sample", SAMPLE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("savepoint", SAVEPOINT, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("schema", SCHEMA, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("scroll", SCROLL, UNRESERVED_KEYWORD)
|
||||
|
@ -542,6 +547,7 @@ PG_KEYWORD("table", TABLE, RESERVED_KEYWORD)
|
|||
PG_KEYWORD("tables", TABLES, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("tablesample", TABLESAMPLE, TYPE_FUNC_NAME_KEYWORD)
|
||||
PG_KEYWORD("tablespace", TABLESPACE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("target", TARGET, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("temp", TEMP, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("template", TEMPLATE, UNRESERVED_KEYWORD)
|
||||
PG_KEYWORD("temporary", TEMPORARY, UNRESERVED_KEYWORD)
|
||||
|
|
|
@ -108,6 +108,8 @@ typedef enum {
|
|||
DestBatchLocalRedistribute, /* results send to consumer thread in a local redistribute way */
|
||||
DestBatchLocalRoundRobin, /* results send to consumer thread in a local roundrobin way */
|
||||
|
||||
DestTrainModel, /* results send to DB4AI model warehouse */
|
||||
|
||||
DestBatchHybrid,
|
||||
DestTransientRel /* results sent to transient relation */
|
||||
|
||||
|
|
|
@ -102,6 +102,7 @@ enum ModuleId {
|
|||
MOD_THREAD_POOL, /* thread_pool */
|
||||
MOD_OPT_AI, /* ai optimizer */
|
||||
MOD_GEN_COL, /* generated column */
|
||||
MOD_DB4AI, /* DB4AI & AUTOML */
|
||||
|
||||
/* add your module id above */
|
||||
MOD_MAX
|
||||
|
|
|
@ -60,6 +60,8 @@ enum SysCacheIdentifier {
|
|||
DATABASEOID,
|
||||
DATASOURCENAME,
|
||||
DATASOURCEOID,
|
||||
DB4AI_MODELOID,
|
||||
DB4AI_MODEL,
|
||||
DEFACLROLENSPOBJ,
|
||||
DIRECTORYNAME,
|
||||
DIRECTORYOID,
|
||||
|
|
|
@ -266,6 +266,9 @@ fastcheck_single: all tablespace-setup
|
|||
$(call hotpatch_setup_func) && \
|
||||
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule0$(PART) -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)
|
||||
|
||||
fastcheck_single_db4ai: all tablespace-setup
|
||||
export LD_LIBRARY_PATH=$(SSL_LIB_PATH):$(LD_LIBRARY_PATH) && \
|
||||
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule.db4ai -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)
|
||||
fastcheck_single_mot: all tablespace-setup
|
||||
export LD_LIBRARY_PATH=$(SSL_LIB_PATH):$(LD_LIBRARY_PATH) && \
|
||||
$(pg_regress_check) $(REGRESS_OPTS) -d 1 -c 0 -p $(p) -r $(runtest) -b $(dir) -n $(n) --abs_gausshome=$(abs_gausshome) --single_node --schedule=$(srcdir)/parallel_schedule20 -w --keep_last_data=${keep_last_data} $(MAXCONNOPT) --temp-config=$(srcdir)/make_fastcheck_single_mot_postgresql.conf $(EXTRA_TESTS) $(REG_CONF)
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
3151, 1,0,0, 0.655000000000000027, 0.505000000000000004, 0.165000000000000008, 1.36699999999999999, 0.583500000000000019, 0.351499999999999979, 0.396000000000000019, 10
|
||||
2026, 1,0,0, 0.550000000000000044, 0.469999999999999973, 0.149999999999999994, 0.920499999999999985, 0.381000000000000005, 0.243499999999999994, 0.267500000000000016, 10
|
||||
3751, 0,1,0, 0.434999999999999998, 0.375, 0.110000000000000001, 0.41549999999999998, 0.170000000000000012, 0.0759999999999999981, 0.14499999999999999, 8
|
||||
720, 0,1,0, 0.149999999999999994, 0.100000000000000006, 0.0250000000000000014, 0.0149999999999999994, 0.00449999999999999966, 0.00400000000000000008, 0.0050000000000000001, 2
|
||||
1635, 1,0,0, 0.574999999999999956, 0.469999999999999973, 0.154999999999999999, 1.1160000000000001, 0.509000000000000008, 0.237999999999999989, 0.340000000000000024, 10
|
||||
2648, 0,1,0, 0.5, 0.390000000000000013, 0.125, 0.582999999999999963, 0.293999999999999984, 0.132000000000000006, 0.160500000000000004, 8
|
||||
1796, 1,0,0, 0.57999999999999996, 0.429999999999999993, 0.170000000000000012, 1.47999999999999998, 0.65349999999999997, 0.32400000000000001, 0.41549999999999998, 10
|
||||
209, 1,0,0, 0.525000000000000022, 0.41499999999999998, 0.170000000000000012, 0.832500000000000018, 0.275500000000000023, 0.168500000000000011, 0.309999999999999998, 13
|
||||
1451, 0,1,0, 0.455000000000000016, 0.33500000000000002, 0.135000000000000009, 0.501000000000000001, 0.274000000000000021, 0.0995000000000000051, 0.106499999999999997, 7
|
||||
1108, 0,1,0, 0.510000000000000009, 0.380000000000000004, 0.115000000000000005, 0.515499999999999958, 0.214999999999999997, 0.113500000000000004, 0.166000000000000009, 8
|
||||
3675, 1,0,0, 0.594999999999999973, 0.450000000000000011, 0.165000000000000008, 1.08099999999999996, 0.489999999999999991, 0.252500000000000002, 0.279000000000000026, 12
|
||||
2108, 1,0,0, 0.675000000000000044, 0.550000000000000044, 0.179999999999999993, 1.68849999999999989, 0.562000000000000055, 0.370499999999999996, 0.599999999999999978, 15
|
||||
3312, 1,0,0, 0.479999999999999982, 0.380000000000000004, 0.135000000000000009, 0.507000000000000006, 0.191500000000000004, 0.13650000000000001, 0.154999999999999999, 12
|
||||
882, 0,0,1, 0.655000000000000027, 0.520000000000000018, 0.165000000000000008, 1.40949999999999998, 0.585999999999999965, 0.290999999999999981, 0.405000000000000027, 9
|
||||
3402, 0,0,1, 0.479999999999999982, 0.395000000000000018, 0.149999999999999994, 0.681499999999999995, 0.214499999999999996, 0.140500000000000014, 0.2495, 18
|
||||
829, 0,1,0, 0.409999999999999976, 0.325000000000000011, 0.100000000000000006, 0.394000000000000017, 0.20799999999999999, 0.0655000000000000027, 0.105999999999999997, 6
|
||||
1305, 0,0,1, 0.535000000000000031, 0.434999999999999998, 0.149999999999999994, 0.716999999999999971, 0.347499999999999976, 0.14449999999999999, 0.194000000000000006, 9
|
||||
3613, 0,0,1, 0.599999999999999978, 0.46000000000000002, 0.179999999999999993, 1.1399999999999999, 0.422999999999999987, 0.257500000000000007, 0.364999999999999991, 10
|
||||
1068, 0,1,0, 0.340000000000000024, 0.265000000000000013, 0.0800000000000000017, 0.201500000000000012, 0.0899999999999999967, 0.0475000000000000006, 0.0550000000000000003, 5
|
||||
2446, 0,0,1, 0.5, 0.380000000000000004, 0.135000000000000009, 0.583500000000000019, 0.22950000000000001, 0.126500000000000001, 0.179999999999999993, 12
|
||||
1393, 0,0,1, 0.635000000000000009, 0.474999999999999978, 0.170000000000000012, 1.19350000000000001, 0.520499999999999963, 0.269500000000000017, 0.366499999999999992, 10
|
||||
359, 0,0,1, 0.744999999999999996, 0.584999999999999964, 0.214999999999999997, 2.49900000000000011, 0.92649999999999999, 0.471999999999999975, 0.699999999999999956, 17
|
||||
549, 1,0,0, 0.564999999999999947, 0.450000000000000011, 0.160000000000000003, 0.79500000000000004, 0.360499999999999987, 0.155499999999999999, 0.23000000000000001, 12
|
||||
1154, 1,0,0, 0.599999999999999978, 0.474999999999999978, 0.160000000000000003, 1.02649999999999997, 0.484999999999999987, 0.2495, 0.256500000000000006, 9
|
||||
1790, 1,0,0, 0.54500000000000004, 0.385000000000000009, 0.149999999999999994, 1.11850000000000005, 0.542499999999999982, 0.244499999999999995, 0.284499999999999975, 9
|
||||
3703, 1,0,0, 0.665000000000000036, 0.540000000000000036, 0.195000000000000007, 1.76400000000000001, 0.850500000000000034, 0.361499999999999988, 0.469999999999999973, 11
|
||||
1962, 1,0,0, 0.655000000000000027, 0.515000000000000013, 0.179999999999999993, 1.41199999999999992, 0.619500000000000051, 0.248499999999999999, 0.496999999999999997, 11
|
||||
1665, 0,1,0, 0.604999999999999982, 0.469999999999999973, 0.14499999999999999, 0.802499999999999991, 0.379000000000000004, 0.226500000000000007, 0.220000000000000001, 9
|
||||
635, 0,0,1, 0.359999999999999987, 0.294999999999999984, 0.100000000000000006, 0.210499999999999993, 0.0660000000000000031, 0.0524999999999999981, 0.0749999999999999972, 9
|
||||
3901, 0,0,1, 0.445000000000000007, 0.344999999999999973, 0.140000000000000013, 0.475999999999999979, 0.205499999999999988, 0.101500000000000007, 0.108499999999999999, 15
|
||||
2734, 0,1,0, 0.41499999999999998, 0.33500000000000002, 0.100000000000000006, 0.357999999999999985, 0.169000000000000011, 0.067000000000000004, 0.104999999999999996, 7
|
||||
3856, 0,0,1, 0.409999999999999976, 0.33500000000000002, 0.115000000000000005, 0.440500000000000003, 0.190000000000000002, 0.0850000000000000061, 0.135000000000000009, 8
|
||||
827, 0,1,0, 0.395000000000000018, 0.28999999999999998, 0.0950000000000000011, 0.303999999999999992, 0.127000000000000002, 0.0840000000000000052, 0.076999999999999999, 6
|
||||
3381, 0,1,0, 0.190000000000000002, 0.130000000000000004, 0.0449999999999999983, 0.0264999999999999993, 0.00899999999999999932, 0.0050000000000000001, 0.00899999999999999932, 5
|
||||
3972, 0,1,0, 0.400000000000000022, 0.294999999999999984, 0.0950000000000000011, 0.252000000000000002, 0.110500000000000001, 0.0575000000000000025, 0.0660000000000000031, 6
|
||||
1155, 0,0,1, 0.599999999999999978, 0.455000000000000016, 0.170000000000000012, 1.1915, 0.695999999999999952, 0.239499999999999991, 0.239999999999999991, 8
|
||||
3467, 0,0,1, 0.640000000000000013, 0.5, 0.170000000000000012, 1.4544999999999999, 0.642000000000000015, 0.357499999999999984, 0.353999999999999981, 9
|
||||
2433, 1,0,0, 0.609999999999999987, 0.484999999999999987, 0.165000000000000008, 1.08699999999999997, 0.425499999999999989, 0.232000000000000012, 0.380000000000000004, 11
|
||||
552, 0,1,0, 0.614999999999999991, 0.489999999999999991, 0.154999999999999999, 0.988500000000000045, 0.41449999999999998, 0.195000000000000007, 0.344999999999999973, 13
|
||||
1425, 1,0,0, 0.729999999999999982, 0.57999999999999996, 0.190000000000000002, 1.73750000000000004, 0.678499999999999992, 0.434499999999999997, 0.520000000000000018, 11
|
||||
2402, 1,0,0, 0.584999999999999964, 0.41499999999999998, 0.154999999999999999, 0.69850000000000001, 0.299999999999999989, 0.145999999999999991, 0.195000000000000007, 12
|
||||
1748, 1,0,0, 0.699999999999999956, 0.535000000000000031, 0.174999999999999989, 1.77299999999999991, 0.680499999999999994, 0.479999999999999982, 0.512000000000000011, 15
|
||||
3983, 0,1,0, 0.57999999999999996, 0.434999999999999998, 0.149999999999999994, 0.891499999999999959, 0.362999999999999989, 0.192500000000000004, 0.251500000000000001, 6
|
||||
335, 1,0,0, 0.739999999999999991, 0.599999999999999978, 0.195000000000000007, 1.97399999999999998, 0.597999999999999976, 0.408499999999999974, 0.709999999999999964, 16
|
||||
1587, 0,1,0, 0.515000000000000013, 0.349999999999999978, 0.104999999999999996, 0.474499999999999977, 0.212999999999999995, 0.122999999999999998, 0.127500000000000002, 10
|
||||
2448, 0,1,0, 0.275000000000000022, 0.204999999999999988, 0.0800000000000000017, 0.096000000000000002, 0.0359999999999999973, 0.0184999999999999991, 0.0299999999999999989, 6
|
||||
1362, 1,0,0, 0.604999999999999982, 0.474999999999999978, 0.174999999999999989, 1.07600000000000007, 0.463000000000000023, 0.219500000000000001, 0.33500000000000002, 9
|
||||
2799, 0,0,1, 0.640000000000000013, 0.484999999999999987, 0.149999999999999994, 1.09800000000000009, 0.519499999999999962, 0.222000000000000003, 0.317500000000000004, 10
|
||||
1413, 1,0,0, 0.67000000000000004, 0.505000000000000004, 0.174999999999999989, 1.01449999999999996, 0.4375, 0.271000000000000019, 0.3745, 10
|
||||
1739, 1,0,0, 0.67000000000000004, 0.540000000000000036, 0.195000000000000007, 1.61899999999999999, 0.739999999999999991, 0.330500000000000016, 0.465000000000000024, 11
|
||||
1152, 0,0,1, 0.584999999999999964, 0.465000000000000024, 0.160000000000000003, 0.955500000000000016, 0.45950000000000002, 0.235999999999999988, 0.265000000000000013, 7
|
||||
2427, 0,1,0, 0.564999999999999947, 0.434999999999999998, 0.154999999999999999, 0.782000000000000028, 0.271500000000000019, 0.16800000000000001, 0.284999999999999976, 14
|
||||
1777, 0,0,1, 0.484999999999999987, 0.369999999999999996, 0.154999999999999999, 0.967999999999999972, 0.418999999999999984, 0.245499999999999996, 0.236499999999999988, 9
|
||||
3294, 0,0,1, 0.574999999999999956, 0.455000000000000016, 0.184999999999999998, 1.15599999999999992, 0.552499999999999991, 0.242999999999999994, 0.294999999999999984, 13
|
||||
1403, 0,0,1, 0.650000000000000022, 0.510000000000000009, 0.190000000000000002, 1.54200000000000004, 0.715500000000000025, 0.373499999999999999, 0.375, 9
|
||||
2256, 0,0,1, 0.510000000000000009, 0.395000000000000018, 0.14499999999999999, 0.61850000000000005, 0.215999999999999998, 0.138500000000000012, 0.239999999999999991, 12
|
||||
3984, 1,0,0, 0.584999999999999964, 0.450000000000000011, 0.125, 0.873999999999999999, 0.354499999999999982, 0.20749999999999999, 0.225000000000000006, 6
|
||||
1116, 0,0,1, 0.525000000000000022, 0.405000000000000027, 0.119999999999999996, 0.755499999999999949, 0.3755, 0.155499999999999999, 0.201000000000000012, 9
|
||||
1366, 0,0,1, 0.609999999999999987, 0.474999999999999978, 0.170000000000000012, 1.02649999999999997, 0.434999999999999998, 0.233500000000000013, 0.303499999999999992, 10
|
||||
3759, 0,1,0, 0.525000000000000022, 0.400000000000000022, 0.140000000000000013, 0.605500000000000038, 0.260500000000000009, 0.107999999999999999, 0.209999999999999992, 9
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue