明硕 2024-05-04 09:50:40 +08:00
parent 317cd4fff1
commit b23816d53c
1 changed file with 17 additions and 4 deletions


@@ -10,12 +10,13 @@ eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
+n_embd = 16
# ------------
torch.manual_seed(1337)
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
-with open('../data/Xiyou.txt', 'r', encoding='utf-8') as f:
+with open('input.txt', 'r', encoding='utf-8') as f:
text = f.read()
# here are all the unique characters that occur in this text
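The unchanged line above ("here are all the unique characters...") is where the script builds its character-level vocabulary from the loaded text. A minimal sketch of that pattern, assuming the variable names from Karpathy's tutorial that this file follows (not copied verbatim from the file):

```python
# build the vocabulary: every distinct character in the corpus, sorted
chars = sorted(list(set(text)))
vocab_size = len(chars)

# lookup tables between characters and integer ids
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]         # string -> list of ints
decode = lambda l: ''.join(itos[i] for i in l)  # list of ints -> string
```

With a Chinese corpus like Xiyou.txt the same code works unchanged; only vocab_size grows to the number of distinct characters in the text.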
@@ -60,13 +61,25 @@ def estimate_loss():
# super simple bigram model
class BigramLanguageModel(nn.Module):
+    # 我是孙悟空 ("I am Sun Wukong")
+    # 空是我悟孙 (the same characters in scrambled order)
    def __init__(self, vocab_size):
        super().__init__()
-        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
+        # each token directly reads off the logits for the next token from a lookup table
+        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)  # (B,T) --> (B,T,n_embd)
+        self.pos_embedding_table = nn.Embedding(block_size, n_embd)  # (T) --> (T,n_embd)
+        self.head = nn.Linear(n_embd, vocab_size)  # (B,T,n_embd) @ (n_embd,vocab_size) --> (B,T,vocab_size)
    def forward(self, idx, targets=None):
-        logits = self.token_embedding_table(idx) # (B,T,C)
+        # idx and targets are both (B,T) tensors of integers
+        B, T = idx.shape
+        x_emb = self.token_embedding_table(idx)  # (B,T) --> (B,T,n_embd)
+        # torch.arange(T, device=device): for T=8 this gives 0,1,2,3,4,5,6,7
+        p_emb = self.pos_embedding_table(torch.arange(T, device=device))  # (T) --> (T,n_embd)
+        x = x_emb + p_emb  # (T,n_embd) broadcasts over the batch --> (B,T,n_embd)
+        logits = self.head(x)  # (B,T,n_embd) --> (B,T,vocab_size)
        if targets is None:
            loss = None
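The hunk is truncated here, but the shape algebra the new lines rely on can be checked in isolation. A runnable sketch under assumed toy sizes (B=4, T=8, n_embd=16, vocab_size=65 are illustrative, not taken from the file), showing the (T,n_embd) position embedding broadcasting over the batch, and the usual cross-entropy flattening that follows the `targets is None` branch in this tutorial lineage:

```python
import torch
import torch.nn.functional as F

B, T, n_embd, vocab_size = 4, 8, 16, 65   # toy sizes, assumed for illustration

x_emb = torch.randn(B, T, n_embd)         # stands in for token_embedding_table(idx)
p_emb = torch.randn(T, n_embd)            # stands in for pos_embedding_table(arange(T))
x = x_emb + p_emb                         # (T,n_embd) broadcasts over B --> (B,T,n_embd)

head = torch.nn.Linear(n_embd, vocab_size)
logits = head(x)                          # (B,T,n_embd) --> (B,T,vocab_size)

# when targets are given, cross_entropy expects (N,C) logits and (N,) targets,
# so batch and time are flattened together
targets = torch.randint(0, vocab_size, (B, T))
loss = F.cross_entropy(logits.view(B * T, vocab_size), targets.view(B * T))
print(x.shape, logits.shape, loss.item())
```

At initialization the loss should sit near -ln(1/vocab_size), i.e. about 4.17 for a 65-character vocabulary, which is a quick sanity check on the reshape.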