Prevent shardTracker and trackShardBytes from unsafely accessing a destroyed DataDistributionTracker

This commit is contained in:
sfc-gh-tclinkenbeard 2020-11-16 12:46:21 -08:00
parent ca8ea3b6ff
commit 6235d087a6
1 changed files with 38 additions and 45 deletions

View File

@ -81,6 +81,28 @@ struct DataDistributionTracker {
// be accessed
bool const& trackerCancelled;
// This class extracts the trackerCancelled reference from a DataDistributionTracker object.
// Because some actors spawned by the DataDistributionTracker outlive the DataDistributionTracker
// object, we must guard against memory errors by using a SafeAccessor functor to access
// the DataDistributionTracker object.
// Guards actors that may outlive the DataDistributionTracker that spawned
// them. The trackerCancelled reference is captured up front (its storage is
// assumed to remain valid after the tracker itself is gone -- see the
// trackerCancelled member declaration), so every access to the tracker first
// checks for cancellation and throws dd_tracker_cancelled instead of
// dereferencing freed memory.
class SafeAccessor {
	bool const& trackerCancelled; // remains readable after the tracker is destroyed
	DataDistributionTracker& tracker;

public:
	// explicit: an implicit conversion from a raw tracker pointer would let a
	// call site opt in to the safety check without meaning to.
	explicit SafeAccessor(DataDistributionTracker* tracker)
	  : trackerCancelled(tracker->trackerCancelled), tracker(*tracker) {
		// The accessor must be created while the tracker is still alive.
		ASSERT(!trackerCancelled);
	}

	// Returns the tracker if it is still alive; throws dd_tracker_cancelled
	// once the tracker has been cancelled, so callers never touch a dead object.
	DataDistributionTracker* operator()() {
		if (trackerCancelled) {
			throw dd_tracker_cancelled();
		}
		return &tracker;
	}
};
DataDistributionTracker(Database cx, UID distributorId, Promise<Void> const& readyToStart,
PromiseStream<RelocateShard> const& output,
Reference<ShardsAffectedByTeamFailure> shardsAffectedByTeamFailure,
@ -140,36 +162,8 @@ int64_t getMaxShardSize( double dbSizeEstimate ) {
(int64_t)SERVER_KNOBS->MAX_SHARD_BYTES);
}
// This class extracts the trackerCancelled reference from a DataDistributionTracker object
// Because some actors spawned by the dataDistributionTracker outlive the DataDistributionTracker
// object, we must guard against memory errors by using a GetTracker functor to access
// the DataDistributionTracker object.
//
// Ideally this would be implemented with a lambda instead, but the actor compiler does not do
// type deduction.
// Functor guarding access to a DataDistributionTracker that may already have
// been destroyed: it snapshots the trackerCancelled reference at construction
// and re-checks it on every call before handing out the tracker pointer.
class GetTracker {
// Reference captured from the tracker; assumed to outlive the tracker object
// itself so it stays readable after destruction -- TODO confirm at the
// trackerCancelled declaration site.
bool const& trackerCancelled;
DataDistributionTracker& tracker;
public:
// Must be constructed while the tracker is still alive.
GetTracker(DataDistributionTracker* tracker) : trackerCancelled(tracker->trackerCancelled), tracker(*tracker) {
ASSERT(!trackerCancelled);
}
// Returns the tracker if still alive; throws dd_tracker_cancelled once the
// tracker has been cancelled, preventing a use-after-free.
DataDistributionTracker* operator()() {
if (trackerCancelled) {
throw dd_tracker_cancelled();
}
return &tracker;
}
};
ACTOR Future<Void> trackShardBytes(
DataDistributionTracker* self,
KeyRange keys,
Reference<AsyncVar<Optional<ShardMetrics>>> shardSize)
{
state GetTracker getSelf(self);
ACTOR Future<Void> trackShardBytes(DataDistributionTracker::SafeAccessor self, KeyRange keys,
Reference<AsyncVar<Optional<ShardMetrics>>> shardSize) {
state BandwidthStatus bandwidthStatus = shardSize->get().present() ? getBandwidthStatus( shardSize->get().get().metrics ) : BandwidthStatusNormal;
state double lastLowBandwidthStartTime = shardSize->get().present() ? shardSize->get().get().lastLowBandwidthStartTime : now();
state int shardCount = shardSize->get().present() ? shardSize->get().get().shardCount : 1;
@ -219,7 +213,7 @@ ACTOR Future<Void> trackShardBytes(
bounds.permittedError.iosPerKSecond = bounds.permittedError.infinity;
loop {
Transaction tr(getSelf()->cx);
Transaction tr(self()->cx);
// metrics.second is the number of key-ranges (i.e., shards) in the 'keys' key-range
std::pair<Optional<StorageMetrics>, int> metrics = wait( tr.waitStorageMetrics( keys, bounds.min, bounds.max, bounds.permittedError, CLIENT_KNOBS->STORAGE_METRICS_SHARD_LIMIT, shardCount ) );
if(metrics.first.present()) {
@ -243,10 +237,10 @@ ACTOR Future<Void> trackShardBytes(
.detail("TrackerID", trackerID);*/
if( shardSize->get().present() ) {
getSelf()->dbSizeEstimate->set(getSelf()->dbSizeEstimate->get() + metrics.first.get().bytes -
shardSize->get().get().metrics.bytes);
self()->dbSizeEstimate->set(self()->dbSizeEstimate->get() + metrics.first.get().bytes -
shardSize->get().get().metrics.bytes);
if(keys.begin >= systemKeys.begin) {
getSelf()->systemSizeEstimate +=
self()->systemSizeEstimate +=
metrics.first.get().bytes - shardSize->get().get().metrics.bytes;
}
}
@ -266,7 +260,7 @@ ACTOR Future<Void> trackShardBytes(
} catch( Error &e ) {
if (e.code() != error_code_actor_cancelled && e.code() != error_code_broken_promise &&
e.code() != error_code_dd_tracker_cancelled) {
getSelf()->output.sendError(e); // Propagate failure to dataDistributionTracker
self()->output.sendError(e); // Propagate failure to dataDistributionTracker
}
throw e;
}
@ -630,15 +624,14 @@ ACTOR Future<Void> shardEvaluator(
return Void();
}
ACTOR Future<Void> shardTracker(DataDistributionTracker* self, KeyRange keys,
ACTOR Future<Void> shardTracker(DataDistributionTracker::SafeAccessor self, KeyRange keys,
Reference<AsyncVar<Optional<ShardMetrics>>> shardSize) {
state GetTracker getSelf(self);
wait( yieldedFuture(self->readyToStart.getFuture()) );
wait(yieldedFuture(self()->readyToStart.getFuture()));
if( !shardSize->get().present() )
wait( shardSize->onChange() );
if (!getSelf()->maxShardSize->get().present()) wait(yieldedFuture(getSelf()->maxShardSize->onChange()));
if (!self()->maxShardSize->get().present()) wait(yieldedFuture(self()->maxShardSize->onChange()));
// Since maxShardSize will become present for all shards at once, avoid slow tasks with a short delay
wait( delay( 0, TaskPriority::DataDistribution ) );
@ -646,18 +639,18 @@ ACTOR Future<Void> shardTracker(DataDistributionTracker* self, KeyRange keys,
// Survives multiple calls to shardEvaluator and keeps merges from happening too quickly.
state Reference<HasBeenTrueFor> wantsToMerge( new HasBeenTrueFor( shardSize->get() ) );
/*TraceEvent("ShardTracker", getSelf()->distributorId)
/*TraceEvent("ShardTracker", self()->distributorId)
.detail("Begin", keys.begin)
.detail("End", keys.end)
.detail("TrackerID", trackerID)
.detail("MaxBytes", getSelf()->maxShardSize->get().get())
.detail("MaxBytes", self()->maxShardSize->get().get())
.detail("ShardSize", shardSize->get().get().bytes)
.detail("BytesPerKSec", shardSize->get().get().bytesPerKSecond);*/
try {
loop {
// Use the current known size to check for (and start) splits and merges.
wait(shardEvaluator(getSelf(), keys, shardSize, wantsToMerge));
wait(shardEvaluator(self(), keys, shardSize, wantsToMerge));
// We could have a lot of actors being released from the previous wait at the same time. Immediately calling
// delay(0) mitigates the resulting SlowTask
@ -667,7 +660,7 @@ ACTOR Future<Void> shardTracker(DataDistributionTracker* self, KeyRange keys,
// If e is broken_promise then self may have already been deleted
if (e.code() != error_code_actor_cancelled && e.code() != error_code_broken_promise &&
e.code() != error_code_dd_tracker_cancelled) {
getSelf()->output.sendError(e); // Propagate failure to dataDistributionTracker
self()->output.sendError(e); // Propagate failure to dataDistributionTracker
}
throw e;
}
@ -699,8 +692,8 @@ void restartShardTrackers(DataDistributionTracker* self, KeyRangeRef keys, Optio
ShardTrackedData data;
data.stats = shardSize;
data.trackShard = shardTracker( self, ranges[i], shardSize );
data.trackBytes = trackShardBytes( self, ranges[i], shardSize );
data.trackShard = shardTracker(DataDistributionTracker::SafeAccessor(self), ranges[i], shardSize);
data.trackBytes = trackShardBytes(DataDistributionTracker::SafeAccessor(self), ranges[i], shardSize);
self->shards.insert( ranges[i], data );
}
}