mirror of https://github.com/vllm-project/vllm
[torch.compile] add deepseek v2 compile (#9775)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
97b61bfae6
commit
76ed5340f0
|
@ -28,6 +28,7 @@ from torch import nn
|
|||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
|
@ -403,6 +404,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class DeepseekV2Model(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
|
Loading…
Reference in New Issue