slowllama: fix cleanup

Author: Oleksandr Kuvshynov
Date:   2023-10-16 20:38:22 -04:00
Parent: 8adb19c518
Commit: 0922036ee8

1 changed file with 2 additions and 2 deletions


@@ -366,13 +366,13 @@ class Transformer(nn.Module):
         del current
         for i, (layer, rng_state, lora) in enumerate(zip(reversed(self.layers), reversed(rng_before), reversed(self.lora_layers))):
-            cleanup_cache()
+            cleanup_cache(device)
             restore_rng_state(rng_state, device=device)
             # first, do feed_forward
             last_grad = self.backprop_w_lora(layer.feed_forward, last_grad)
             # now, do attention
-            cleanup_cache()
+            cleanup_cache(device)
             last_grad = self.backprop_w_lora(layer.attention, last_grad, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
             logging.log(level=logging.DEBUG, msg=f'combined: transformer block {i} done')
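
For context, the change threads the active device into cleanup_cache so the allocator cache is cleared on the backend that actually holds the offloaded block. A minimal sketch of what a device-aware cleanup_cache could look like; the name and parameter come from the diff, but the body below is an assumption, not the repository's implementation:

    import gc
    import torch

    def cleanup_cache(device):
        # Hypothetical sketch, not slowllama's actual helper: free cached
        # allocator blocks on the given backend between transformer blocks,
        # keeping peak memory low while layers are swapped in and out.
        gc.collect()
        dev = torch.device(device)
        if dev.type == 'cuda':
            torch.cuda.empty_cache()
        elif dev.type == 'mps':
            torch.mps.empty_cache()

Passing device explicitly matters because slowllama runs on Apple Silicon as well as CUDA GPUs, and on MPS it is torch.mps.empty_cache(), not torch.cuda.empty_cache(), that releases the cached memory.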