code clean
This commit is contained in:
parent
143c305286
commit
a005e9d08b
|
@ -119,8 +119,8 @@ bool PriorityReplayBufferSampleCpuKernel::Launch(const std::vector<AddressPtr> &
|
||||||
for (size_t transition_index = 0; transition_index < samples.size(); transition_index++) {
|
for (size_t transition_index = 0; transition_index < samples.size(); transition_index++) {
|
||||||
const std::vector<AddressPtr> &transition = samples[transition_index];
|
const std::vector<AddressPtr> &transition = samples[transition_index];
|
||||||
for (size_t item_index = 0; item_index < schema_.size(); item_index++) {
|
for (size_t item_index = 0; item_index < schema_.size(); item_index++) {
|
||||||
void *offset =
|
void *offset = reinterpret_cast<uint8_t *>(outputs[item_index + kTransitionIndex]->addr) +
|
||||||
reinterpret_cast<char *>(outputs[item_index + kTransitionIndex]->addr) + schema_[item_index] * transition_index;
|
schema_[item_index] * transition_index;
|
||||||
MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(offset, outputs[item_index + kTransitionIndex]->size,
|
MS_EXCEPTION_IF_CHECK_FAIL(memcpy_s(offset, outputs[item_index + kTransitionIndex]->size,
|
||||||
transition[item_index]->addr, transition[item_index]->size) == EOK,
|
transition[item_index]->addr, transition[item_index]->size) == EOK,
|
||||||
"memcpy_s() failed.");
|
"memcpy_s() failed.");
|
||||||
|
|
|
@ -594,7 +594,7 @@ CNodePtr NeighborExchangeV2Fusion::CreateMiddleConcat(const FuncGraphPtr &graph,
|
||||||
if (concat_dim == kWDim) {
|
if (concat_dim == kWDim) {
|
||||||
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.begin() + 1);
|
||||||
} else {
|
} else {
|
||||||
int64_t bottom_num = AllToAllRealIds(4, recv_rank_ids);
|
int64_t bottom_num = AllToAllRealIds(kRankIdFour, recv_rank_ids);
|
||||||
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
|
concat_input_all.insert(concat_input_all.end(), all_to_all_v_outputs.begin() + bottom_num,
|
||||||
all_to_all_v_outputs.begin() + bottom_num + 1);
|
all_to_all_v_outputs.begin() + bottom_num + 1);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue