[Misc]Add customized information for models (#4132)

This commit is contained in:
Jee Li 2024-05-01 12:18:14 +08:00 committed by GitHub
parent c3845d82dc
commit d6f4bd7cdd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 87 additions and 0 deletions

View File

@ -43,3 +43,18 @@ def test_models(
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
vllm_model = vllm_runner(model, dtype=dtype)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
del vllm_model

View File

@ -49,3 +49,18 @@ def test_models(
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
vllm_model = vllm_runner(model, dtype=dtype)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
del vllm_model

View File

@ -47,3 +47,10 @@ class Attention(nn.Module):
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
return s

View File

@ -67,6 +67,9 @@ class GeluAndMul(nn.Module):
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module):

View File

@ -64,3 +64,8 @@ class RMSNorm(nn.Module):
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s

View File

@ -181,6 +181,12 @@ class ReplicatedLinear(LinearBase):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@ -281,6 +287,14 @@ class ColumnParallelLinear(LinearBase):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
@ -685,3 +699,11 @@ class RowParallelLinear(LinearBase):
output = output_
output_bias = self.bias
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s

View File

@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module):
logits = logits[:, :self.org_vocab_size]
return logits
def extra_repr(self) -> str:
s = f"vocab_size={self.vocab_size}"
s += f", forg_vocab_size={self.org_vocab_size}"
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
return s
def _prune_hidden_states(
hidden_states: torch.Tensor,

View File

@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
self.cos_sin_cache, self.is_neox_style)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.

View File

@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}'
s += f', tp_size={self.tp_size}'
return s
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.