slowllama: fix cleanup
parent 8adb19c518
commit 0922036ee8
@@ -366,13 +366,13 @@ class Transformer(nn.Module):
             del current

         for i, (layer, rng_state, lora) in enumerate(zip(reversed(self.layers), reversed(rng_before), reversed(self.lora_layers))):
-            cleanup_cache()
+            cleanup_cache(device)
             restore_rng_state(rng_state, device=device)
             # first, do feed_forward
             last_grad = self.backprop_w_lora(layer.feed_forward, last_grad)

             # now, do attention
-            cleanup_cache()
+            cleanup_cache(device)
             last_grad = self.backprop_w_lora(layer.attention, last_grad, freqs_cos, freqs_sin, lora['q_lora'], lora['v_lora'])
             logging.log(level=logging.DEBUG, msg=f'combined: transformer block {i} done')

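The change threads the target device into the cache cleanup call before each offloaded backprop step, so the cache belonging to that device is the one being freed. The repository's own cleanup_cache helper is not shown in this hunk; the sketch below is only an assumption of what a device-aware version could look like, not slowllama's actual implementation. It assumes device is a string such as "cpu", "cuda", or "mps".

# Hypothetical sketch, not the project's helper: device-aware cache cleanup.
import gc
import torch

def cleanup_cache(device):
    # Drop Python-level references first so the allocator can actually release blocks.
    gc.collect()
    if device.startswith('cuda') and torch.cuda.is_available():
        # Return cached CUDA blocks to the driver.
        torch.cuda.empty_cache()
    elif device == 'mps' and torch.backends.mps.is_available():
        # Return cached Metal buffers on Apple Silicon.
        torch.mps.empty_cache()

With a device parameter, the two cleanup calls inside the transformer-block loop above can free memory on whichever backend the offloaded weights were moved to, instead of assuming a single global device.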