parent 317cd4fff1
commit b23816d53c
@@ -10,12 +10,13 @@ eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 16
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('../data/Xiyou.txt', 'r', encoding='utf-8') as f:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
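The comment above points at the vocabulary-building step that follows in the script but is not part of this hunk. A minimal sketch of that step, following the standard character-level recipe; the names chars, stoi, itos, encode and decode are assumptions, since those lines are not shown in this diff:

# "text" is the corpus string read in the hunk above
chars = sorted(list(set(text)))                 # every unique character in the corpus
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}    # char -> integer id
itos = {i: ch for i, ch in enumerate(chars)}    # integer id -> char
encode = lambda s: [stoi[c] for c in s]         # string -> list of ids
decode = lambda l: ''.join(itos[i] for i in l)  # list of ids -> string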
@@ -60,13 +61,25 @@ def estimate_loss():
# super simple bigram model
class BigramLanguageModel(nn.Module):

    # 我是孙悟空 ("I am Sun Wukong")
    # 空是我悟孙 (the same characters in a different order -- the order changes the meaning)

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)  # (B,T) --> (B,T,n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)    # (T) --> (T,n_embd)
        self.head = nn.Linear(n_embd, vocab_size)                      # (B,T,n_embd) @ (n_embd,vocab_size) --> (B,T,vocab_size)
    def forward(self, idx, targets=None):

        logits = self.token_embedding_table(idx) # (B,T,C)
        # idx and targets are both (B,T) tensor of integers
        x_emb = self.token_embedding_table(idx)  # (B,T) --> (B,T,n_embd)
        # torch.arange(T, device=device) gives 0,1,...,T-1 (e.g. 0..7 for T=8)
        p_emb = self.pos_embedding_table(torch.arange(T, device=device))  # (T) --> (T,n_embd)

        x = x_emb + p_emb  # (B,T,n_embd)

        logits = self.head(x)  # (B,T,n_embd) -> (B,T,vocab_size)

        if targets is None:
            loss = None
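The hunk cuts off right at the targets is None branch, and the shown lines also rely on T being defined (typically from B, T = idx.shape), which is not visible here. A hedged sketch of how the rest of forward() usually computes the loss in this style of script, assuming the standard (B*T, C) reshape for F.cross_entropy; it is not taken from the unchanged lines below the hunk:

import torch
import torch.nn.functional as F

# sketch of the full method; "self" is the BigramLanguageModel from the diff above
def forward(self, idx, targets=None):
    B, T = idx.shape                                   # T feeds the positional indices
    x_emb = self.token_embedding_table(idx)            # (B,T,n_embd)
    p_emb = self.pos_embedding_table(torch.arange(T, device=idx.device))  # (T,n_embd)
    x = x_emb + p_emb                                   # (B,T,n_embd)
    logits = self.head(x)                               # (B,T,vocab_size)

    if targets is None:
        loss = None
    else:
        B, T, C = logits.shape
        logits = logits.view(B * T, C)                  # F.cross_entropy expects (N, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
    return logits, loss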
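As a self-contained sanity check of the shape arithmetic the new lines depend on: a (T, n_embd) positional embedding broadcasts against a (B, T, n_embd) token embedding, and the linear head maps n_embd back to vocab_size. The sizes below are illustrative only, not taken from the script:

import torch
import torch.nn as nn

B, T, n_embd, vocab_size = 4, 8, 16, 65        # illustrative sizes only
tok_emb = nn.Embedding(vocab_size, n_embd)
pos_emb = nn.Embedding(T, n_embd)              # block_size rows in the real script
head = nn.Linear(n_embd, vocab_size)

idx = torch.randint(0, vocab_size, (B, T))     # (B,T) integer tokens
x = tok_emb(idx) + pos_emb(torch.arange(T))    # (B,T,n_embd) + (T,n_embd) -> (B,T,n_embd)
logits = head(x)                               # (B,T,vocab_size)
print(x.shape, logits.shape)                   # torch.Size([4, 8, 16]) torch.Size([4, 8, 65])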