Merge pull request #2999 from tclinken/more-move-optimizations

Generate rvalue reference overloads for actor callback functions
This commit is contained in:
Evan Tschannen 2020-05-07 14:32:27 -07:00 committed by GitHub
commit 574914640a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 121 additions and 28 deletions

View File

@ -80,6 +80,7 @@ class LambdaCallback : public CallbackType, public FastAllocated<LambdaCallback<
ErrFunc errFunc;
virtual void fire(T const& t) { CallbackType::remove(); func(t); delete this; }
virtual void fire(T && t) { CallbackType::remove(); func(std::move(t)); delete this; }
virtual void error(Error e) { CallbackType::remove(); errFunc(e); delete this; }
public:
@ -1384,21 +1385,23 @@ struct Tracker {
this->copied = other.copied + 1;
return *this;
}
~Tracker() = default;
ACTOR static Future<Void> listen(FutureStream<Tracker> stream) {
Tracker t = waitNext(stream);
ASSERT(!t.moved);
ASSERT(t.copied == 0);
Tracker movedTracker = waitNext(stream);
ASSERT(!movedTracker.moved);
ASSERT(movedTracker.copied == 0);
return Void();
}
};
TEST_CASE("/flow/flow/PromiseStream/move") {
state PromiseStream<Tracker> stream;
state Future<Void> listener;
{
// This tests the case when a callback is added before
// a movable value is sent
state Future<Void> listener = Tracker::listen(stream.getFuture());
listener = Tracker::listen(stream.getFuture());
stream.send(Tracker{});
wait(listener);
}
@ -1417,15 +1420,14 @@ TEST_CASE("/flow/flow/PromiseStream/move") {
stream.send(Tracker{});
stream.send(Tracker{});
{
Tracker t = waitNext(stream.getFuture());
ASSERT(!t.moved);
ASSERT(t.copied == 0);
state Tracker movedTracker = waitNext(stream.getFuture());
ASSERT(!movedTracker.moved);
ASSERT(movedTracker.copied == 0);
}
choose {
when(Tracker t = waitNext(stream.getFuture())) {
ASSERT(!t.moved);
ASSERT(t.copied == 0);
}
{
Tracker movedTracker = waitNext(stream.getFuture());
ASSERT(!movedTracker.moved);
ASSERT(movedTracker.copied == 0);
}
}
{
@ -1436,19 +1438,29 @@ TEST_CASE("/flow/flow/PromiseStream/move") {
stream.send(namedTracker1);
stream.send(namedTracker2);
{
Tracker t = waitNext(stream.getFuture());
ASSERT(!t.moved);
state Tracker copiedTracker = waitNext(stream.getFuture());
ASSERT(!copiedTracker.moved);
// must copy onto queue
ASSERT(t.copied == 1);
ASSERT(copiedTracker.copied == 1);
}
choose {
when(Tracker t = waitNext(stream.getFuture())) {
ASSERT(!t.moved);
// must copy onto queue
ASSERT(t.copied == 1);
}
{
Tracker copiedTracker = waitNext(stream.getFuture());
ASSERT(!copiedTracker.moved);
// must copy onto queue
ASSERT(copiedTracker.copied == 1);
}
}
return Void();
}
TEST_CASE("/flow/flow/PromiseStream/move2") {
PromiseStream<Tracker> stream;
stream.send(Tracker{});
Tracker tracker = waitNext(stream.getFuture());
Tracker movedTracker = std::move(tracker);
ASSERT(tracker.moved);
ASSERT(!movedTracker.moved);
ASSERT(movedTracker.copied == 0);
return Void();
}

View File

@ -86,31 +86,65 @@ namespace actorcompiler
string indentation;
StreamWriter body;
public bool wasCalled { get; protected set; }
public Function overload = null;
public Function()
{
body = new StreamWriter(new MemoryStream());
}
public void setOverload(Function overload) {
this.overload = overload;
}
public Function popOverload() {
Function result = this.overload;
this.overload = null;
return result;
}
public void addOverload(params string[] formalParameters) {
setOverload(
new Function {
name = name,
returnType = returnType,
endIsUnreachable = endIsUnreachable,
formalParameters = formalParameters
}
);
}
public void Indent(int change)
{
for(int i=0; i<change; i++) indentation += '\t';
if (change < 0) indentation = indentation.Substring(-change);
if (overload != null) {
overload.Indent(change);
}
}
public void WriteLineUnindented(string s)
{
body.WriteLine(s);
if (overload != null) {
overload.WriteLineUnindented(s);
}
}
public void WriteLine(string line)
{
body.Write(indentation);
body.WriteLine(line);
if (overload != null) {
overload.WriteLine(line);
}
}
public void WriteLine(string line, params object[] args)
{
body.Write(indentation);
body.WriteLine(line, args);
if (overload != null) {
overload.WriteLine(line, args);
}
}
public string BodyText
@ -754,8 +788,9 @@ namespace actorcompiler
Group = group,
Index = this.whenCount+i,
Body = getFunction(cx.target.name, "when",
string.Format("{0} const& {2}{1}", ch.wait.result.type, ch.wait.result.name, ch.wait.resultIsState?"__":""),
loopDepth),
new string[] { string.Format("{0} const& {2}{1}", ch.wait.result.type, ch.wait.result.name, ch.wait.resultIsState?"__":""), loopDepth },
new string[] { string.Format("{0} && {2}{1}", ch.wait.result.type, ch.wait.result.name, ch.wait.resultIsState?"__":""), loopDepth }
),
Future = string.Format("__when_expr_{0}", this.whenCount + i),
CallbackType = string.Format("{3}< {0}, {1}, {2} >", fullClassName, this.whenCount + i, ch.wait.result.type, ch.wait.isWaitNext ? "ActorSingleCallback" : "ActorCallback"),
CallbackTypeInStateClass = string.Format("{3}< {0}, {1}, {2} >", className, this.whenCount + i, ch.wait.result.type, ch.wait.isWaitNext ? "ActorSingleCallback" : "ActorCallback")
@ -784,6 +819,7 @@ namespace actorcompiler
var r = ch.Body;
if (ch.Stmt.wait.resultIsState)
{
Function overload = r.popOverload();
CompileStatement(new StateDeclarationStatement
{
FirstSourceLine = ch.Stmt.FirstSourceLine,
@ -794,6 +830,11 @@ namespace actorcompiler
initializerConstructorSyntax = false
}
}, cx.WithTarget(r));
if (overload != null)
{
overload.WriteLine("{0} = std::move(__{0});", ch.Stmt.wait.result.name);
r.setOverload(overload);
}
}
if (ch.Stmt.body != null)
{
@ -804,8 +845,14 @@ namespace actorcompiler
reachable = true;
if (cx.next.formalParameters.Length == 1)
r.WriteLine("loopDepth = {0};", cx.next.call("loopDepth"));
else
else {
Function overload = r.popOverload();
r.WriteLine("loopDepth = {0};", cx.next.call(ch.Stmt.wait.result.name, "loopDepth"));
if (overload != null) {
overload.WriteLine("loopDepth = {0};", cx.next.call(string.Format("std::move({0})", ch.Stmt.wait.result.name), "loopDepth"));
r.setOverload(overload);
}
}
}
var cbFunc = new Function {
@ -817,13 +864,22 @@ namespace actorcompiler
},
endIsUnreachable = true
};
cbFunc.addOverload(ch.CallbackTypeInStateClass + "*", ch.Stmt.wait.result.type + " && value");
functions.Add(string.Format("{0}#{1}", cbFunc.name, ch.Index), cbFunc);
cbFunc.Indent(codeIndent);
ProbeEnter(cbFunc, actor.name, ch.Index);
cbFunc.WriteLine("{0};", exitFunc.call());
Function _overload = cbFunc.popOverload();
TryCatch(cx.WithTarget(cbFunc), cx.catchFErr, cx.tryLoopDepth, () => {
cbFunc.WriteLine("{0};", ch.Body.call("value", "0"));
}, false);
if (_overload != null) {
TryCatch(cx.WithTarget(_overload), cx.catchFErr, cx.tryLoopDepth, () => {
_overload.WriteLine("{0};", ch.Body.call("std::move(value)", "0"));
}, false);
cbFunc.setOverload(_overload);
}
ProbeExit(cbFunc, actor.name, ch.Index);
var errFunc = new Function
@ -916,10 +972,14 @@ namespace actorcompiler
},
FirstSourceLine = stmt.FirstSourceLine
};
if (!stmt.resultIsState)
if (!stmt.resultIsState) {
cx.next.formalParameters = new string[] {
string.Format("{0} const& {1}", stmt.result.type, stmt.result.name),
loopDepth };
cx.next.addOverload(
string.Format("{0} && {1}", stmt.result.type, stmt.result.name),
loopDepth);
}
CompileStatement(equiv, cx);
}
void CompileStatement(CodeBlock stmt, Context cx)
@ -1116,6 +1176,14 @@ namespace actorcompiler
{
WriteFunction(writer, func, body);
}
if (func.overload != null)
{
string overloadBody = func.overload.BodyText;
if (overloadBody.Length != 0)
{
WriteFunction(writer, func.overload, overloadBody);
}
}
}
}
@ -1133,7 +1201,7 @@ namespace actorcompiler
writer.WriteLine(memberIndentStr + "}");
}
Function getFunction(string baseName, string addName, params string[] formalParameters)
Function getFunction(string baseName, string addName, string[] formalParameters, string[] overloadFormalParameters)
{
string proposedName;
if (addName == "cont" && baseName.Length>=5 && baseName.Substring(baseName.Length - 5, 4) == "cont")
@ -1149,10 +1217,19 @@ namespace actorcompiler
returnType = "int",
formalParameters = formalParameters
};
if (overloadFormalParameters != null) {
f.addOverload(overloadFormalParameters);
}
f.Indent(codeIndent);
functions.Add(f.name, f);
return f;
}
Function getFunction(string baseName, string addName, params string[] formalParameters)
{
return getFunction(baseName, addName, formalParameters, null);
}
string[] ParameterList()
{
return actor.parameters.Select(p =>

View File

@ -391,6 +391,7 @@ struct SingleCallback {
SingleCallback<T> *next;
virtual void fire(T const&) {}
virtual void fire(T &&) {}
virtual void error(Error) {}
virtual void unwait() {}
@ -1014,10 +1015,13 @@ struct ActorCallback : Callback<ValueType> {
template <class ActorType, int CallbackNumber, class ValueType>
struct ActorSingleCallback : SingleCallback<ValueType> {
virtual void fire(ValueType const& value) {
virtual void fire(ValueType const& value) override {
static_cast<ActorType*>(this)->a_callback_fire(this, value);
}
virtual void error(Error e) {
virtual void fire(ValueType && value) override {
static_cast<ActorType*>(this)->a_callback_fire(this, std::move(value));
}
virtual void error(Error e) override {
static_cast<ActorType*>(this)->a_callback_error(this, e);
}
};