import%20marimo%0A%0A__generated_with%20%3D%20%220.18.4%22%0Aapp%20%3D%20marimo.App()%0A%0Awith%20app.setup%3A%0A%20%20%20%20from%20importlib.metadata%20import%20version%0A%20%20%20%20from%20pathlib%20import%20Path%0A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20tiktoken%0A%20%20%20%20import%20torch%0A%20%20%20%20from%20ch02%20import%20create_dataloader_v1%2C%20download_verdict_data%0A%20%20%20%20from%20ch04%20import%20GPTModel%2C%20generate_text_simple%0A%20%20%20%20from%20gpt_download%20import%20download_and_load_gpt2%0A%20%20%20%20from%20matplotlib.ticker%20import%20MaxNLocator%0A%0A%20%20%20%20_pkgs%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%22matplotlib%22%2C%0A%20%20%20%20%20%20%20%20%22numpy%22%2C%0A%20%20%20%20%20%20%20%20%22tiktoken%22%2C%0A%20%20%20%20%20%20%20%20%22torch%22%2C%0A%20%20%20%20%20%20%20%20%22tensorflow%22%2C%20%20%23%20For%20OpenAI's%20pretrained%20weights%0A%20%20%20%20%5D%0A%0A%20%20%20%20for%20_p%20in%20_pkgs%3A%0A%20%20%20%20%20%20%20%20print(f%22%7B_p%7D%20version%3A%20%7Bversion(_p)%7D%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Chapter%205%3A%20Pretraining%20on%20Unlabeled%20Data%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%205.1%20Evaluating%20generative%20text%20models%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.1.1%20Using%20GPT%20to%20generate%20text%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20GPT_CONFIG_124M%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22vocab_size%22%3A%2050257%2C%20%20%23%20Vocabulary%20size%0A%20%20%20%20%20%20%20%20%22context_length%22%3A%20256%2C%20%20%23%20Shortened%20context%20length%20(orig%3A%201024)%0A%20%20%20%20%20%20%20%20%22emb_dim%22%3A%20768%2C%20%20%23%20Embedding%20dimension%0A%20%20%20%20%20%20%20%20%22n_heads%22%3A%2012%2C%20%20%23%20Number%20of%20attention%20heads%0A%20%20%20%20%20%20%20%20%22n_layers%22%3A%2012%2C%20%20%23%20Number%20of%20layers%0A%20%20%20%20%20%20%20%20%22drop_rate%22%3A%200.1%2C%20%20%23%20Dropout%20rate%0A%20%20%20%20%20%20%20%20%22qkv_bias%22%3A%20False%2C%20%20%23%20Query-key-value%20bias%0A%20%20%20%20%7D%0A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20random_model%20%3D%20GPTModel(GPT_CONFIG_124M)%0A%20%20%20%20random_model.eval()%3B%20%20%23%20Disable%20dropout%20during%20inference%0A%20%20%20%20return%20GPT_CONFIG_124M%2C%20random_model%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Utility%20function%20to%20convert%20from%20text%20to%20token%20ID%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20text_to_token_ids(text%2C%20tokenizer)%3A%0A%20%20%20%20%23%20EOT%20is%20needed%20for%20padding%0A%20%20%20%20encoded%20%3D%20tokenizer.encode(text%2C%20allowed_special%3D%7B%22%3C%7Cendoftext%7C%3E%22%7D)%0A%20%20%20%20encoded_tensor%20%3D%20torch.tensor(encoded).unsqueeze(0)%20%20%23%20add%20batch%20dimension%0A%20%20%20%20return%20encoded_tensor%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Utility%20function%20to%20convert%20from%20token%20ID%20to%20text%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20token_ids_to_text(token_ids%2C%20tokenizer)%3A%0A%20%20%20%20flat%20%3D%20token_ids.squeeze(0)%20%20%23%20remove%20batch%20dimension%0A%20%20%20%20return%20tokenizer.decode(flat.tolist())%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20to%20use%20the%20above%20functions.%0A%20%20%20%20The%20output%20does%20not%20make%20sense%20because%20it%20is%20not%20trained%20yet.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20random_model)%3A%0A%20%20%20%20_start_context%20%3D%20%22Every%20effort%20moves%20you%22%0A%20%20%20%20tokenizer%20%3D%20tiktoken.get_encoding(%22gpt2%22)%0A%0A%20%20%20%20_token_ids%20%3D%20generate_text_simple(%0A%20%20%20%20%20%20%20%20model%3Drandom_model%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(_start_context%2C%20tokenizer)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D10%2C%0A%20%20%20%20%20%20%20%20context_size%3DGPT_CONFIG_124M%5B%22context_length%22%5D%0A%20%20%20%20)%0A%0A%20%20%20%20print(%22Output%20text%3A%5Cn%22%2C%20token_ids_to_text(_token_ids%2C%20tokenizer))%0A%20%20%20%20return%20(tokenizer%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.1.2%20Calculating%20the%20text%20generation%20loss%3A%20cross-entropy%20and%20perplexity%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20show%20the%20way%20to%20calculate%20training%20loss%2C%20these%20inputs%20and%20targets%20are%20used.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20inputs%20%3D%20torch.tensor(%5B%5B16833%2C%203626%2C%206100%5D%2C%20%20%20%23%20%5B%22every%20effort%20moves%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B40%2C%20%20%20%201107%2C%20588%5D%5D)%20%20%20%23%20%20%22I%20really%20like%22%5D%0A%0A%20%20%20%20targets%20%3D%20torch.tensor(%5B%5B3626%2C%206100%2C%20345%20%20%5D%2C%20%20%23%20%5B%22%20effort%20moves%20you%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B1107%2C%20%20588%2C%2011311%5D%5D)%20%23%20%20%22%20really%20like%20chocolate%22%5D%0A%20%20%20%20return%20inputs%2C%20targets%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Get%20probabilities%20of%20next%20token%20IDs%20by%20using%20random_model%20inference%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs%2C%20random_model)%3A%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20logits%20%3D%20random_model(inputs)%0A%0A%20%20%20%20probas%20%3D%20torch.softmax(logits%2C%20dim%3D-1)%20%23%20Probability%20of%20each%20token%20in%20vocabulary%0A%20%20%20%20print(f%22%7Bprobas.shape%3D%7D%22)%20%23%20Shape%3A%20(batch_size%2C%20num_tokens%2C%20vocab_size)%0A%20%20%20%20return%20logits%2C%20probas%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Use%20greedy%20way%20to%20get%20next%20token%20IDs%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(probas)%3A%0A%20%20%20%20token_ids%20%3D%20torch.argmax(probas%2C%20dim%3D-1%2C%20keepdim%3DTrue)%0A%20%20%20%20print(%22Token%20IDs%3A%5Cn%22%2C%20token_ids)%0A%20%20%20%20return%20(token_ids%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Compare%20token%20IDs%20between%20targets%20and%20predicted%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(targets%2C%20token_ids%2C%20tokenizer)%3A%0A%20%20%20%20print(f%22Targets%20batch%201%3A%20%7Btoken_ids_to_text(targets%5B0%5D%2C%20tokenizer)%7D%22)%0A%20%20%20%20print(f%22Outputs%20batch%201%3A%20%7Btoken_ids_to_text(token_ids%5B0%5D.flatten()%2C%20tokenizer)%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20evaluate%20the%20differences%20quantitatively.%0A%0A%20%20%20%20Get%20the%20probabilities%20(likelihood)%20for%20targets%20in%20the%20above%20result.%0A%0A%20%20%20%20Remind%20the%20%60probas%60%20has%20(batch_size%2C%20num_tokens%2C%20vocab_size)%20shape.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(probas%2C%20targets)%3A%0A%20%20%20%20_text_idx%20%3D%200%0A%20%20%20%20target_probas_1%20%3D%20probas%5B_text_idx%2C%20%5B0%2C%201%2C%202%5D%2C%20targets%5B_text_idx%5D%5D%0A%20%20%20%20print(%22Text%201%3A%22%2C%20target_probas_1)%0A%0A%20%20%20%20_text_idx%20%3D%201%0A%20%20%20%20target_probas_2%20%3D%20probas%5B_text_idx%2C%20%5B0%2C%201%2C%202%5D%2C%20targets%5B_text_idx%5D%5D%0A%20%20%20%20print(%22Text%202%3A%22%2C%20target_probas_2)%0A%20%20%20%20return%20target_probas_1%2C%20target_probas_2%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Take%20%60log%60%20of%20these%20to%20get%20entropy%20(or%20log%20likelihood).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(target_probas_1%2C%20target_probas_2)%3A%0A%20%20%20%20%23%20Compute%20logarithm%20of%20all%20token%20probabilities%0A%20%20%20%20log_probas%20%3D%20torch.log(torch.cat((target_probas_1%2C%20target_probas_2)))%0A%20%20%20%20print(log_probas)%0A%20%20%20%20return%20(log_probas%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Get%20average%20scores%20of%20it.%0A%20%20%20%20This%20is%20called%20as%20cross%20entropy%20loss.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(log_probas)%3A%0A%20%20%20%20%23%20Calculate%20the%20average%20probability%20for%20each%20token%0A%20%20%20%20avg_log_probas%20%3D%20torch.mean(log_probas)%0A%20%20%20%20print(avg_log_probas)%0A%20%20%20%20return%20(avg_log_probas%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Minimization%20of%20minus%20log%20is%20preferred%20than%20maximization%20of%20plus%20log%20for%20implementations.%0A%0A%20%20%20%20This%20is%20also%20called%20as%20negative%20log%20likelihood%20(nll).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(avg_log_probas)%3A%0A%20%20%20%20_neg_avg_log_probas%20%3D%20avg_log_probas%20*%20-1%0A%20%20%20%20print(_neg_avg_log_probas)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20proceed%20these%20calculate%20by%20PyTorch.%0A%20%20%20%20Before%20doing%20so%2C%20check%20the%20shapes%20of%20logits%20and%20targets%20to%20compare.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(logits%2C%20targets)%3A%0A%20%20%20%20%23%20Logits%20have%20shape%20(batch_size%2C%20num_tokens%2C%20vocab_size)%0A%20%20%20%20print(%22Logits%20shape%3A%22%2C%20logits.shape)%0A%0A%20%20%20%20%23%20Targets%20have%20shape%20(batch_size%2C%20num_tokens)%0A%20%20%20%20print(%22Targets%20shape%3A%22%2C%20targets.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20apply%20PyTorch%20function%2C%20%60flatten%60%20is%20needed.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(logits%2C%20targets)%3A%0A%20%20%20%20logits_flat%20%3D%20logits.flatten(0%2C%201)%0A%20%20%20%20targets_flat%20%3D%20targets.flatten()%0A%0A%20%20%20%20print(%22Flattened%20logits%3A%22%2C%20logits_flat.shape)%0A%20%20%20%20print(%22Flattened%20targets%3A%22%2C%20targets_flat.shape)%0A%20%20%20%20return%20logits_flat%2C%20targets_flat%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Then%2C%20pass%20these%20to%20%60cross_entropy%60%20to%20take%20softmax%2C%20log%2C%20mean%20and%20the%20minus.%0A%20%20%20%20This%20results%20coincides%20the%20previous%20result.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(logits_flat%2C%20targets_flat)%3A%0A%20%20%20%20loss%20%3D%20torch.nn.functional.cross_entropy(logits_flat%2C%20targets_flat)%0A%20%20%20%20print(loss)%0A%20%20%20%20return%20(loss%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20exponential%20of%20cross%20entropy%20is%20called%20as%20perplexity.%0A%0A%20%20%20%20If%20the%20probabilistic%20distribution%20is%20uniform%20(the%20most%20uncertain%20case)%2C%20we%20can%20write%20the%20perplexity%20by%20using%20the%20size%20of%20vocabulary%20%24N%24%20as%0A%0A%20%20%20%20%24%24%0A%20%20%20%20PP%0A%20%20%20%20%3D%5Cexp%5Cleft(-%5Csum_%7Bi%3D1%7D%5EN%5Cfrac%7B1%7D%7BN%7D%5Clog%5Cfrac%7B1%7D%7BN%7D%5Cright)%0A%20%20%20%20%3D%5Cexp%5Cleft(-%5Clog%5Cfrac%7B1%7D%7BN%7D%5Cright)%3DN.%0A%20%20%20%20%24%24%0A%0A%20%20%20%20So%20it%20has%20meaning%20of%20uncertainty%20for%20prediction%20in%20the%20unit%20of%20number%20of%20vocabularies.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loss)%3A%0A%20%20%20%20_perplexity%20%3D%20torch.exp(loss)%0A%20%20%20%20print(_perplexity)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.1.3%20Calculating%20the%20training%20and%20validation%20set%20losses%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Calculate%20such%20loss%20on%20public%20and%20tiny%20dataset.%0A%20%20%20%20At%20first%2C%20it%20should%20be%20downloaded.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20file_path%20%3D%20download_verdict_data()%0A%20%20%20%20with%20open(file_path%2C%20%22r%22%2C%20encoding%3D%22utf-8%22)%20as%20_file%3A%0A%20%20%20%20%20%20%20%20text_data%20%3D%20_file.read()%0A%20%20%20%20return%20(text_data%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20See%20the%20top.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(text_data)%3A%0A%20%20%20%20%23%20First%2099%20characters%0A%20%20%20%20print(text_data%5B%3A99%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20See%20the%20tail%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(text_data)%3A%0A%20%20%20%20%23%20Last%2099%20characters%0A%20%20%20%20print(text_data%5B-99%3A%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20See%20the%20scale%20of%20this%20dataset.%20It%20is%20tiny%20and%20enough%20to%20try%20training%20trial.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(text_data%2C%20tokenizer)%3A%0A%20%20%20%20_total_characters%20%3D%20len(text_data)%0A%20%20%20%20total_tokens%20%3D%20len(tokenizer.encode(text_data))%0A%0A%20%20%20%20print(%22Characters%3A%22%2C%20_total_characters)%0A%20%20%20%20print(%22Tokens%3A%22%2C%20total_tokens)%0A%20%20%20%20return%20(total_tokens%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Split%20the%20dataset%20and%20create%20dataloaders%20in%0A%20%20%20%20%24%24%0A%20%20%20%20%5Ctext%7Btrain%7D%3A%5Ctext%7Bvalid%7D%3D90%5C%25%3A10%5C%25.%0A%20%20%20%20%24%24%0A%0A%20%20%20%20The%20function%20%60create_dataloader_v1%60%20is%20defined%20on%20%5Bthe%20Chapter%202%5D(.%2Fch02.py).%0A%20%20%20%20See%20the%20definition%20in%20there.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20text_data)%3A%0A%20%20%20%20%23%20Train%2Fvalidation%20ratio%0A%20%20%20%20train_ratio%20%3D%200.90%0A%20%20%20%20split_idx%20%3D%20int(train_ratio%20*%20len(text_data))%0A%20%20%20%20train_data%20%3D%20text_data%5B%3Asplit_idx%5D%0A%20%20%20%20val_data%20%3D%20text_data%5Bsplit_idx%3A%5D%0A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20train_loader%20%3D%20create_dataloader_v1(%0A%20%20%20%20%20%20%20%20train_data%2C%0A%20%20%20%20%20%20%20%20batch_size%3D2%2C%20%20%20%23%20to%20save%20computation%20resources%0A%20%20%20%20%20%20%20%20max_length%3DGPT_CONFIG_124M%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20stride%3DGPT_CONFIG_124M%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20drop_last%3DTrue%2C%0A%20%20%20%20%20%20%20%20shuffle%3DTrue%2C%0A%20%20%20%20%20%20%20%20num_workers%3D0%0A%20%20%20%20)%0A%0A%20%20%20%20val_loader%20%3D%20create_dataloader_v1(%0A%20%20%20%20%20%20%20%20val_data%2C%0A%20%20%20%20%20%20%20%20batch_size%3D2%2C%0A%20%20%20%20%20%20%20%20max_length%3DGPT_CONFIG_124M%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20stride%3DGPT_CONFIG_124M%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20drop_last%3DFalse%2C%0A%20%20%20%20%20%20%20%20shuffle%3DFalse%2C%0A%20%20%20%20%20%20%20%20num_workers%3D0%0A%20%20%20%20)%0A%20%20%20%20return%20train_loader%2C%20train_ratio%2C%20val_loader%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Sanity%20check%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20total_tokens%2C%20train_ratio)%3A%0A%20%20%20%20if%20total_tokens%20*%20(train_ratio)%20%3C%20GPT_CONFIG_124M%5B%22context_length%22%5D%3A%0A%20%20%20%20%20%20%20%20print(%22Not%20enough%20tokens%20for%20the%20training%20loader.%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22Try%20to%20lower%20the%20%60GPT_CONFIG_124M%5B'context_length'%5D%60%20or%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22increase%20the%20%60training_ratio%60%22)%0A%0A%20%20%20%20if%20total_tokens%20*%20(1-train_ratio)%20%3C%20GPT_CONFIG_124M%5B%22context_length%22%5D%3A%0A%20%20%20%20%20%20%20%20print(%22Not%20enough%20tokens%20for%20the%20validation%20loader.%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22Try%20to%20lower%20the%20%60GPT_CONFIG_124M%5B'context_length'%5D%60%20or%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%22decrease%20the%20%60training_ratio%60%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Try%20the%20dataloader%20to%20verify%20the%20behaviors.%0A%20%20%20%20These%20are%20all%20batches%20to%20be%20used.%0A%20%20%20%20Surely%2C%20the%20number%20of%20train%20batch%20is%209%20and%20the%20number%20of%20validation%20batch%20is%201%20that%20reflects%20the%20ratio.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_loader%2C%20val_loader)%3A%0A%20%20%20%20print(%22Train%20loader%3A%22)%0A%20%20%20%20for%20_x%2C%20_y%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20print(_x.shape%2C%20_y.shape)%0A%0A%20%20%20%20print(%22%5CnValidation%20loader%3A%22)%0A%20%20%20%20for%20_x%2C%20_y%20in%20val_loader%3A%0A%20%20%20%20%20%20%20%20print(_x.shape%2C%20_y.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Check%20the%20total%20number%20of%20tokens.%20It%20is%20just%20about%205k.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_loader%2C%20val_loader)%3A%0A%20%20%20%20_train_tokens%20%3D%200%0A%20%20%20%20for%20_input_batch%2C%20_target_batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20_train_tokens%20%2B%3D%20_input_batch.numel()%0A%0A%20%20%20%20_val_tokens%20%3D%200%0A%20%20%20%20for%20_input_batch%2C%20_target_batch%20in%20val_loader%3A%0A%20%20%20%20%20%20%20%20_val_tokens%20%2B%3D%20_input_batch.numel()%0A%0A%20%20%20%20print(%22Training%20tokens%3A%22%2C%20_train_tokens)%0A%20%20%20%20print(%22Validation%20tokens%3A%22%2C%20_val_tokens)%0A%20%20%20%20print(%22All%20tokens%3A%22%2C%20_train_tokens%20%2B%20_val_tokens)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20function%20to%20calculate%20cross%20entropy%20loss%20like%20previous.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%3A%0A%20%20%20%20input_batch%2C%20target_batch%20%3D%20input_batch.to(device)%2C%20target_batch.to(device)%0A%20%20%20%20logits%20%3D%20model(input_batch)%0A%20%20%20%20loss%20%3D%20torch.nn.functional.cross_entropy(logits.flatten(0%2C%201)%2C%20target_batch.flatten())%0A%20%20%20%20return%20loss%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Extend%20the%20functionality%20by%20the%20following%20function%20to%20entire%20dataloaders.%0A%20%20%20%20By%20setting%20the%20%60num_batches%60%20as%20smaller%20values%2C%20we%20can%20compute%20the%20losses%20easier.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20calc_loss_loader(data_loader%2C%20model%2C%20device%2C%20num_batches%3DNone)%3A%0A%20%20%20%20total_loss%20%3D%200.%0A%20%20%20%20if%20len(data_loader)%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20return%20float(%22nan%22)%0A%20%20%20%20elif%20num_batches%20is%20None%3A%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20len(data_loader)%0A%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%23%20Reduce%20the%20number%20of%20batches%20to%20match%20the%20total%20number%20of%20batches%20in%20the%20data%20loader%0A%20%20%20%20%20%20%20%20%23%20if%20num_batches%20exceeds%20the%20number%20of%20batches%20in%20the%20data%20loader%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20min(num_batches%2C%20len(data_loader))%0A%20%20%20%20for%20i%2C%20(input_batch%2C%20target_batch)%20in%20enumerate(data_loader)%3A%0A%20%20%20%20%20%20%20%20if%20i%20%3C%20num_batches%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%0A%20%20%20%20%20%20%20%20%20%20%20%20total_loss%20%2B%3D%20loss.item()%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20return%20total_loss%20%2F%20num_batches%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20utility%20function%20to%20detect%20suitable%20device%20for%20training%20and%20inference.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20get_torch_device()%3A%0A%20%20%20%20device%20%3D%20torch.device(%22cpu%22)%0A%20%20%20%20if%20torch.cuda.is_available()%3A%0A%20%20%20%20%20%20%20%20device%20%3D%20torch.device(%22cuda%22)%0A%20%20%20%20elif%20torch.backends.mps.is_available()%3A%0A%20%20%20%20%20%20%20%20%23%20Use%20PyTorch%202.9%20or%20newer%20for%20stable%20mps%20results%0A%20%20%20%20%20%20%20%20major%2C%20minor%20%3D%20map(int%2C%20torch.__version__.split(%22.%22)%5B%3A2%5D)%0A%20%20%20%20%20%20%20%20if%20(major%2C%20minor)%20%3E%3D%20(2%2C%209)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20device%20%3D%20torch.device(%22mps%22)%0A%20%20%20%20return%20device%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20the%20first%20trial%20to%20compute%20the%20loss%20for%20entire%20dataloaders.%0A%20%20%20%20We%20need%20to%20decrease%20the%20loss%20than%20these.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(random_model%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20device%20%3D%20get_torch_device()%0A%20%20%20%20print(f%22Using%20%7Bdevice%7D%20device.%22)%0A%0A%20%20%20%20random_model.to(device)%20%23%20no%20assignment%20model%20%3D%20model.to(device)%20necessary%20for%20nn.Module%20classes%0A%0A%20%20%20%20torch.manual_seed(123)%20%23%20For%20reproducibility%20due%20to%20the%20shuffling%20in%20the%20data%20loader%0A%0A%20%20%20%20with%20torch.no_grad()%3A%20%23%20Disable%20gradient%20tracking%20for%20efficiency%20because%20we%20are%20not%20training%2C%20yet%0A%20%20%20%20%20%20%20%20train_loss%20%3D%20calc_loss_loader(train_loader%2C%20random_model%2C%20device)%0A%20%20%20%20%20%20%20%20val_loss%20%3D%20calc_loss_loader(val_loader%2C%20random_model%2C%20device)%0A%0A%20%20%20%20print(%22Training%20loss%3A%22%2C%20train_loss)%0A%20%20%20%20print(%22Validation%20loss%3A%22%2C%20val_loss)%0A%20%20%20%20return%20(device%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%205.2%20Training%20an%20LLM%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20training%20process%20is%20like%20this.%20We%20need%20to%20define%20the%20%60evaluate_model()%60%20and%20%60generate_and_print_sample()%60%20later%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20train_model_simple(model%2C%20train_loader%2C%20val_loader%2C%20optimizer%2C%20device%2C%20num_epochs%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20eval_freq%2C%20eval_iter%2C%20start_context%2C%20tokenizer)%3A%0A%20%20%20%20%23%20Initialize%20lists%20to%20track%20losses%20and%20tokens%20seen%0A%20%20%20%20train_losses%2C%20val_losses%2C%20track_tokens_seen%20%3D%20%5B%5D%2C%20%5B%5D%2C%20%5B%5D%0A%20%20%20%20tokens_seen%2C%20global_step%20%3D%200%2C%20-1%0A%0A%20%20%20%20%23%20Main%20training%20loop%0A%20%20%20%20for%20epoch%20in%20range(num_epochs)%3A%0A%20%20%20%20%20%20%20%20model.train()%20%20%23%20Set%20model%20to%20training%20mode%0A%0A%20%20%20%20%20%20%20%20for%20input_batch%2C%20target_batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%20%23%20Reset%20loss%20gradients%20from%20previous%20batch%20iteration%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss.backward()%20%23%20Calculate%20loss%20gradients%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%20%23%20Update%20model%20weights%20using%20loss%20gradients%0A%20%20%20%20%20%20%20%20%20%20%20%20tokens_seen%20%2B%3D%20input_batch.numel()%0A%20%20%20%20%20%20%20%20%20%20%20%20global_step%20%2B%3D%201%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Optional%20evaluation%20step%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20global_step%20%25%20eval_freq%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train_loss%2C%20val_loss%20%3D%20evaluate_model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20model%2C%20train_loader%2C%20val_loader%2C%20device%2C%20eval_iter)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train_losses.append(train_loss)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_losses.append(val_loss)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20track_tokens_seen.append(tokens_seen)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Ep%20%7Bepoch%2B1%7D%20(Step%20%7Bglobal_step%3A06d%7D)%3A%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Train%20loss%20%7Btrain_loss%3A.3f%7D%2C%20Val%20loss%20%7Bval_loss%3A.3f%7D%22)%0A%0A%20%20%20%20%20%20%20%20%23%20Print%20a%20sample%20text%20after%20each%20epoch%0A%20%20%20%20%20%20%20%20generate_and_print_sample(%0A%20%20%20%20%20%20%20%20%20%20%20%20model%2C%20tokenizer%2C%20device%2C%20start_context%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20return%20train_losses%2C%20val_losses%2C%20track_tokens_seen%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20function%20to%20evaluate%20loss%20quantitatively%20for%20entire%20dataloaders.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20evaluate_model(model%2C%20train_loader%2C%20val_loader%2C%20device%2C%20eval_iter)%3A%0A%20%20%20%20model.eval()%20%20%20%20%23%20disable%20dropout%0A%20%20%20%20with%20torch.no_grad()%3A%20%20%20%23%20skip%20gradient%20calculations%0A%20%20%20%20%20%20%20%20train_loss%20%3D%20calc_loss_loader(train_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%20%20%20%20%20%20%20%20val_loss%20%3D%20calc_loss_loader(val_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%20%20%20%20model.train()%20%20%20%20%23%20enable%20dropout%0A%20%20%20%20return%20train_loss%2C%20val_loss%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20function%20to%20evaluate%20the%20output%20qualitatively.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20generate_and_print_sample(model%2C%20tokenizer%2C%20device%2C%20start_context)%3A%0A%20%20%20%20model.eval()%0A%20%20%20%20context_size%20%3D%20model.pos_emb.weight.shape%5B0%5D%0A%20%20%20%20encoded%20%3D%20text_to_token_ids(start_context%2C%20tokenizer).to(device)%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20token_ids%20%3D%20generate_text_simple(%0A%20%20%20%20%20%20%20%20%20%20%20%20model%3Dmodel%2C%20idx%3Dencoded%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20max_new_tokens%3D50%2C%20context_size%3Dcontext_size%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20decoded_text%20%3D%20token_ids_to_text(token_ids%2C%20tokenizer)%0A%20%20%20%20print(decoded_text.replace(%22%5Cn%22%2C%20%22%20%22))%20%20%23%20Compact%20print%20format%0A%20%20%20%20model.train()%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20use%20%60AdamW%60%20Optimizer%20to%20surpress%20overfitting.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20device%2C%20tokenizer%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20model%20%3D%20GPTModel(GPT_CONFIG_124M)%0A%20%20%20%20model.to(device)%0A%20%20%20%20optimizer%20%3D%20torch.optim.AdamW(model.parameters()%2C%20lr%3D0.0004%2C%20weight_decay%3D0.1)%0A%0A%20%20%20%20num_epochs%20%3D%2010%0A%20%20%20%20train_losses%2C%20val_losses%2C%20tokens_seen%20%3D%20train_model_simple(%0A%20%20%20%20%20%20%20%20model%2C%20train_loader%2C%20val_loader%2C%20optimizer%2C%20device%2C%0A%20%20%20%20%20%20%20%20num_epochs%3Dnum_epochs%2C%20eval_freq%3D5%2C%20eval_iter%3D5%2C%0A%20%20%20%20%20%20%20%20start_context%3D%22Every%20effort%20moves%20you%22%2C%20tokenizer%3Dtokenizer%0A%20%20%20%20)%0A%20%20%20%20return%20model%2C%20num_epochs%2C%20optimizer%2C%20tokens_seen%2C%20train_losses%2C%20val_losses%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Create%20directory%20to%20save%20plots%20for%20the%20traning%20process%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(project_root)%3A%0A%20%20%20%20data_root%20%3D%20project_root%20%2F%20%22data%22%20%2F%20%22ch05%22%0A%20%20%20%20data_root.mkdir(parents%3DTrue%2C%20exist_ok%3DTrue)%0A%20%20%20%20return%20(data_root%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20the%20function%20to%20plot%20training%20processes%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20plot_losses(data_root)%3A%0A%20%20%20%20def%20plot_losses(epochs_seen%2C%20tokens_seen%2C%20train_losses%2C%20val_losses)%3A%0A%20%20%20%20%20%20%20%20fig%2C%20ax1%20%3D%20plt.subplots(figsize%3D(5%2C%203))%0A%0A%20%20%20%20%20%20%20%20%23%20Plot%20training%20and%20validation%20loss%20against%20epochs%0A%20%20%20%20%20%20%20%20ax1.plot(epochs_seen%2C%20train_losses%2C%20label%3D%22Training%20loss%22)%0A%20%20%20%20%20%20%20%20ax1.plot(epochs_seen%2C%20val_losses%2C%20linestyle%3D%22-.%22%2C%20label%3D%22Validation%20loss%22)%0A%20%20%20%20%20%20%20%20ax1.set_xlabel(%22Epochs%22)%0A%20%20%20%20%20%20%20%20ax1.set_ylabel(%22Loss%22)%0A%20%20%20%20%20%20%20%20ax1.legend(loc%3D%22upper%20right%22)%0A%20%20%20%20%20%20%20%20ax1.xaxis.set_major_locator(MaxNLocator(integer%3DTrue))%20%20%23%20only%20show%20integer%20labels%20on%20x-axis%0A%0A%20%20%20%20%20%20%20%20%23%20Create%20a%20second%20x-axis%20for%20tokens%20seen%0A%20%20%20%20%20%20%20%20ax2%20%3D%20ax1.twiny()%20%20%23%20Create%20a%20second%20x-axis%20that%20shares%20the%20same%20y-axis%0A%20%20%20%20%20%20%20%20ax2.plot(tokens_seen%2C%20train_losses%2C%20alpha%3D0)%20%20%23%20Invisible%20plot%20for%20aligning%20ticks%0A%20%20%20%20%20%20%20%20ax2.set_xlabel(%22Tokens%20seen%22)%0A%0A%20%20%20%20%20%20%20%20fig.tight_layout()%20%20%23%20Adjust%20layout%20to%20make%20room%0A%20%20%20%20%20%20%20%20plt.savefig(data_root%20%2F%20%22loss-plot.pdf%22)%0A%20%20%20%20%20%20%20%20plt.show()%0A%20%20%20%20return%20(plot_losses%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20result%20shows%20overfitting%20because%20of%20traning%20with%20tiny%20dataset.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(num_epochs%2C%20plot_losses%2C%20tokens_seen%2C%20train_losses%2C%20val_losses)%3A%0A%20%20%20%20epochs_tensor%20%3D%20torch.linspace(0%2C%20num_epochs%2C%20len(train_losses))%0A%20%20%20%20plot_losses(epochs_tensor%2C%20tokens_seen%2C%20train_losses%2C%20val_losses)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%205.3%20Decoding%20strategies%20to%20control%20randomness%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20%60generate_text_simple%60%20is%20based%20on%20greedy%20decoding%20strategy%20and%20it%20is%20deterministic.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20model%2C%20tokenizer)%3A%0A%20%20%20%20%23%20NEW%3A%20use%20CPU%20here%20as%20inference%20is%20cheap%20with%20%0A%20%20%20%20%23%20this%20model%20and%20to%20ensure%20readers%20get%20same%20results%20in%20the%0A%20%20%20%20%23%20remaining%20sections%20of%20this%20book%0A%20%20%20%20inference_device%20%3D%20torch.device(%22cpu%22)%0A%0A%20%20%20%20model.to(inference_device)%0A%20%20%20%20model.eval()%0A%0A%20%20%20%20_token_ids%20%3D%20generate_text_simple(%0A%20%20%20%20%20%20%20%20model%3Dmodel%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(%22Every%20effort%20moves%20you%22%2C%20tokenizer).to(inference_device)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D25%2C%0A%20%20%20%20%20%20%20%20context_size%3DGPT_CONFIG_124M%5B%22context_length%22%5D%0A%20%20%20%20)%0A%0A%20%20%20%20print(%22Output%20text%3A%5Cn%22%2C%20token_ids_to_text(_token_ids%2C%20tokenizer))%0A%20%20%20%20return%20(inference_device%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.3.1%20Temperature%20scaling%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20to%20show%20difference%20between%20deterministic%20and%20probabilistic%20strategies.%0A%20%20%20%20This%20outputs%20deterministic%20result%20by%20greedy%20decoding.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20vocab%20%3D%20%7B%20%0A%20%20%20%20%20%20%20%20%22closer%22%3A%200%2C%0A%20%20%20%20%20%20%20%20%22every%22%3A%201%2C%20%0A%20%20%20%20%20%20%20%20%22effort%22%3A%202%2C%20%0A%20%20%20%20%20%20%20%20%22forward%22%3A%203%2C%0A%20%20%20%20%20%20%20%20%22inches%22%3A%204%2C%0A%20%20%20%20%20%20%20%20%22moves%22%3A%205%2C%20%0A%20%20%20%20%20%20%20%20%22pizza%22%3A%206%2C%0A%20%20%20%20%20%20%20%20%22toward%22%3A%207%2C%0A%20%20%20%20%20%20%20%20%22you%22%3A%208%2C%0A%20%20%20%20%7D%20%0A%0A%20%20%20%20inverse_vocab%20%3D%20%7Bv%3A%20k%20for%20k%2C%20v%20in%20vocab.items()%7D%0A%0A%20%20%20%20%23%20Suppose%20input%20is%20%22every%20effort%20moves%20you%22%2C%20and%20the%20LLM%0A%20%20%20%20%23%20returns%20the%20following%20logits%20for%20the%20next%20token%3A%0A%20%20%20%20next_token_logits%20%3D%20torch.tensor(%0A%20%20%20%20%20%20%20%20%5B4.51%2C%200.89%2C%20-1.90%2C%206.75%2C%201.63%2C%20-1.62%2C%20-1.89%2C%206.28%2C%201.79%5D%0A%20%20%20%20)%0A%0A%20%20%20%20next_probas%20%3D%20torch.softmax(next_token_logits%2C%20dim%3D0)%0A%20%20%20%20next_token_id%20%3D%20torch.argmax(next_probas).item()%0A%0A%20%20%20%20%23%20The%20next%20generated%20token%20is%20then%20as%20follows%3A%0A%20%20%20%20print(inverse_vocab%5Bnext_token_id%5D)%0A%20%20%20%20return%20inverse_vocab%2C%20next_probas%2C%20next_token_logits%2C%20vocab%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Try%20probabilistic%20sampling%20based%20of%20probabilities%20obtained%20by%20the%20logits.%0A%20%20%20%20Even%20the%20%60forward%60%20occurs%20in%20high%20probability%2C%20it%20also%20outputs%20the%20same%20word.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inverse_vocab%2C%20next_probas)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20next_token_id_multinomial%20%3D%20torch.multinomial(next_probas%2C%20num_samples%3D1).item()%0A%20%20%20%20print(inverse_vocab%5Bnext_token_id_multinomial%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Try%20multiple%20times%20to%20see%20the%20probabilistic%20behaviors.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inverse_vocab)%3A%0A%20%20%20%20def%20print_sampled_tokens(probas)%3A%0A%20%20%20%20%20%20%20%20torch.manual_seed(123)%20%23%20Manual%20seed%20for%20reproducibility%0A%20%20%20%20%20%20%20%20sample%20%3D%20%5Btorch.multinomial(probas%2C%20num_samples%3D1).item()%20for%20i%20in%20range(1_000)%5D%0A%20%20%20%20%20%20%20%20sampled_ids%20%3D%20torch.bincount(torch.tensor(sample)%2C%20minlength%3Dlen(probas))%0A%20%20%20%20%20%20%20%20for%20i%2C%20freq%20in%20enumerate(sampled_ids)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22%7Bfreq%7D%20x%20%7Binverse_vocab%5Bi%5D%7D%22)%0A%20%20%20%20return%20(print_sampled_tokens%2C)%0A%0A%0A%40app.cell%0Adef%20_(next_probas%2C%20print_sampled_tokens)%3A%0A%20%20%20%20print_sampled_tokens(next_probas)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20To%20control%20the%20probability%2C%20introduce%20temprature%20scaling%20like%0A%20%20%20%20%24%24%0A%20%20%20%20p(x%3B%5Cbeta)%20%5Cpropto%20%5Cexp%5Cleft(%5Cbeta%20p%5Cright)%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20softmax_with_temperature(logits%2C%20temperature)%3A%0A%20%20%20%20scaled_logits%20%3D%20logits%20%2F%20temperature%0A%20%20%20%20return%20torch.softmax(scaled_logits%2C%20dim%3D0)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Try%20different%20values%20of%20the%20tempretures.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(next_token_logits)%3A%0A%20%20%20%20%23%20Temperature%20values%0A%20%20%20%20temperatures%20%3D%20%5B1%2C%200.1%2C%205%5D%20%20%23%20Original%2C%20higher%20confidence%2C%20and%20lower%20confidence%0A%0A%20%20%20%20%23%20Calculate%20scaled%20probabilities%0A%20%20%20%20scaled_probas%20%3D%20%5Bsoftmax_with_temperature(next_token_logits%2C%20T)%20for%20T%20in%20temperatures%5D%0A%20%20%20%20return%20scaled_probas%2C%20temperatures%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Plot%20it.%20It%20shows%20higher%20temprature%20make%20the%20distribution%20more%20uniform.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(data_root%2C%20scaled_probas%2C%20temperatures%2C%20vocab)%3A%0A%20%20%20%20%23%20Plotting%0A%20%20%20%20x%20%3D%20torch.arange(len(vocab))%0A%20%20%20%20bar_width%20%3D%200.15%0A%0A%20%20%20%20fig%2C%20ax%20%3D%20plt.subplots(figsize%3D(5%2C%203))%0A%20%20%20%20for%20i%2C%20T%20in%20enumerate(temperatures)%3A%0A%20%20%20%20%20%20%20%20rects%20%3D%20ax.bar(x%20%2B%20i%20*%20bar_width%2C%20scaled_probas%5Bi%5D%2C%20bar_width%2C%20label%3Df'Temperature%20%3D%20%7BT%7D')%0A%0A%20%20%20%20ax.set_ylabel('Probability')%0A%20%20%20%20ax.set_xticks(x)%0A%20%20%20%20ax.set_xticklabels(vocab.keys()%2C%20rotation%3D90)%0A%20%20%20%20ax.legend()%0A%0A%20%20%20%20plt.tight_layout()%0A%20%20%20%20plt.savefig(data_root%20%2F%20%22temperature-plot.pdf%22)%0A%20%20%20%20plt.show()%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%60Temperature%3D0.1%60%20case%20causes%20more%20greedy%20like%20distribution.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(print_sampled_tokens%2C%20scaled_probas)%3A%0A%20%20%20%20print_sampled_tokens(scaled_probas%5B1%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%60Temperature%3D5%60%20cause%20more%20diversity%20but%20more%20nonsense%20phrases%20like%20%60pizza%60.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(print_sampled_tokens%2C%20scaled_probas)%3A%0A%20%20%20%20print_sampled_tokens(scaled_probas%5B2%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.3.2%20Top-k%20sampling%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20top-k%20sampling%20strategy%20is%20another%20probabilistic%20model%20using%20cutoff.%0A%20%20%20%20We%20only%20use%20top%20%24k%24%20candidates%20in%20the%20probabilistic%20ranking%20and%20ignore%20others.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(next_token_logits)%3A%0A%20%20%20%20top_k%20%3D%203%0A%20%20%20%20top_logits%2C%20top_pos%20%3D%20torch.topk(next_token_logits%2C%20top_k)%0A%0A%20%20%20%20print(%22Top%20logits%3A%22%2C%20top_logits)%0A%20%20%20%20print(%22Top%20positions%3A%22%2C%20top_pos)%0A%20%20%20%20return%20(top_logits%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Such%20ignored%20candidate%20logits%20are%20masked%20by%20-inf.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(next_token_logits%2C%20top_logits)%3A%0A%20%20%20%20new_logits%20%3D%20torch.where(%0A%20%20%20%20%20%20%20%20condition%3Dnext_token_logits%20%3C%20top_logits%5B-1%5D%2C%0A%20%20%20%20%20%20%20%20input%3Dtorch.tensor(float(%22-inf%22))%2C%20%0A%20%20%20%20%20%20%20%20other%3Dnext_token_logits%0A%20%20%20%20)%0A%0A%20%20%20%20print(new_logits)%0A%20%20%20%20return%20(new_logits%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20result%20probabilities%20are%20these.%0A%20%20%20%20The%20predicted%20token%20will%20be%20sampled%20base%20on%20this%20result.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(new_logits)%3A%0A%20%20%20%20topk_probas%20%3D%20torch.softmax(new_logits%2C%20dim%3D0)%0A%20%20%20%20print(topk_probas)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%205.3.3%20Modifying%20the%20text%20generation%20function%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Include%20temperature%20and%20tok-k%20sampling%20to%20%60generate_text_simple()%60.%0A%20%20%20%20This%20generation%20process%20continues%20until%20the%20EOS%20token%20appears.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20generate(model%2C%20idx%2C%20max_new_tokens%2C%20context_size%2C%20temperature%3D0.0%2C%20top_k%3DNone%2C%20eos_id%3DNone)%3A%0A%20%20%20%20%23%20For-loop%20is%20the%20same%20as%20before%3A%20Get%20logits%2C%20and%20only%20focus%20on%20last%20time%20step%0A%20%20%20%20for%20_%20in%20range(max_new_tokens)%3A%0A%20%20%20%20%20%20%20%20idx_cond%20%3D%20idx%5B%3A%2C%20-context_size%3A%5D%0A%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(idx_cond)%0A%20%20%20%20%20%20%20%20logits%20%3D%20logits%5B%3A%2C%20-1%2C%20%3A%5D%0A%0A%20%20%20%20%20%20%20%20%23%20New%3A%20Filter%20logits%20with%20top_k%20sampling%0A%20%20%20%20%20%20%20%20if%20top_k%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Keep%20only%20top_k%20values%0A%20%20%20%20%20%20%20%20%20%20%20%20top_logits%2C%20_%20%3D%20torch.topk(logits%2C%20top_k)%0A%20%20%20%20%20%20%20%20%20%20%20%20min_val%20%3D%20top_logits%5B%3A%2C%20-1%5D%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20mask%20other%20logits%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20torch.where(logits%20%3C%20min_val%2C%20torch.tensor(float(%22-inf%22)).to(logits.device)%2C%20logits)%0A%0A%20%20%20%20%20%20%20%20%23%20New%3A%20Apply%20temperature%20scaling%0A%20%20%20%20%20%20%20%20if%20temperature%20%3E%200.0%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20logits%20%2F%20temperature%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20New%20(not%20in%20book)%3A%20numerical%20stability%20tip%20to%20get%20equivalent%20results%20on%20mps%20device%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20subtract%20rowwise%20max%20before%20softmax%0A%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20logits%20-%20logits.max(dim%3D-1%2C%20keepdim%3DTrue).values%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Apply%20softmax%20to%20get%20probabilities%0A%20%20%20%20%20%20%20%20%20%20%20%20probs%20%3D%20torch.softmax(logits%2C%20dim%3D-1)%20%20%23%20(batch_size%2C%20context_len)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Sample%20from%20the%20distribution%0A%20%20%20%20%20%20%20%20%20%20%20%20idx_next%20%3D%20torch.multinomial(probs%2C%20num_samples%3D1)%20%20%23%20(batch_size%2C%201)%0A%0A%20%20%20%20%20%20%20%20%23%20Otherwise%20same%20as%20before%3A%20get%20idx%20of%20the%20vocab%20entry%20with%20the%20highest%20logits%20value%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20greedy%20sampling%0A%20%20%20%20%20%20%20%20%20%20%20%20idx_next%20%3D%20torch.argmax(logits%2C%20dim%3D-1%2C%20keepdim%3DTrue)%20%20%23%20(batch_size%2C%201)%0A%0A%20%20%20%20%20%20%20%20if%20idx_next%20%3D%3D%20eos_id%3A%20%20%23%20Stop%20generating%20early%20if%20end-of-sequence%20token%20is%20encountered%20and%20eos_id%20is%20specified%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%0A%20%20%20%20%20%20%20%20%23%20Same%20as%20before%3A%20append%20sampled%20index%20to%20the%20running%20sequence%0A%20%20%20%20%20%20%20%20idx%20%3D%20torch.cat((idx%2C%20idx_next)%2C%20dim%3D1)%20%20%23%20(batch_size%2C%20num_tokens%2B1)%0A%0A%20%20%20%20return%20idx%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20result%20is%20different%20with%20the%20previous.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20inference_device%2C%20model%2C%20tokenizer)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20_token_ids%20%3D%20generate(%0A%20%20%20%20%20%20%20%20model%3Dmodel%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(%22Every%20effort%20moves%20you%22%2C%20tokenizer).to(inference_device)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D15%2C%0A%20%20%20%20%20%20%20%20context_size%3DGPT_CONFIG_124M%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20top_k%3D25%2C%0A%20%20%20%20%20%20%20%20temperature%3D1.4%0A%20%20%20%20)%0A%0A%20%20%20%20print(%22Output%20text%3A%5Cn%22%2C%20token_ids_to_text(_token_ids%2C%20tokenizer))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%205.4%20Loading%20and%20saving%20model%20weights%20in%20PyTorch%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20save%20the%20trained%20moddel%20like%20this.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20project_root%20%3D%20Path(__file__).parent.parent.parent%0A%20%20%20%20model_dir_path%20%3D%20project_root%20%2F%20%22models%22%20%2F%20%22ch05%22%0A%20%20%20%20train_model_path%20%3D%20model_dir_path%20%2F%20%22model.pth%22%0A%20%20%20%20train_model_path.parent.mkdir(parents%3DTrue%2C%20exist_ok%3DTrue)%0A%0A%20%20%20%20torch.save(model.state_dict()%2C%20train_model_path)%0A%20%20%20%20return%20model_dir_path%2C%20project_root%2C%20train_model_path%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20loading%20is%20also%20easy.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20device%2C%20train_model_path)%3A%0A%20%20%20%20loaded_model%20%3D%20GPTModel(GPT_CONFIG_124M)%0A%20%20%20%20print(%22Device%3A%22%2C%20device)%0A%0A%20%20%20%20loaded_model.load_state_dict(torch.load(train_model_path%2C%20map_location%3Ddevice%2C%20weights_only%3DTrue))%0A%20%20%20%20loaded_model.eval()%3B%20%20%20%20%23%20disable%20dropout%0A%20%20%20%20return%20(loaded_model%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20also%20save%20the%20optimizer%20status%2C%20and%20continue%20the%20training%20later.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(loaded_model%2C%20model_dir_path%2C%20optimizer)%3A%0A%20%20%20%20train_model_optimizer_path%20%3D%20model_dir_path%20%2F%20%22model_and_optimizer.pth%22%0A%20%20%20%20train_model_optimizer_path.parent.mkdir(parents%3DTrue%2C%20exist_ok%3DTrue)%0A%0A%20%20%20%20torch.save(%7B%0A%20%20%20%20%20%20%20%20%22model_state_dict%22%3A%20loaded_model.state_dict()%2C%0A%20%20%20%20%20%20%20%20%22optimizer_state_dict%22%3A%20optimizer.state_dict()%2C%0A%20%20%20%20%20%20%20%20%7D%2C%20%0A%20%20%20%20%20%20%20%20train_model_optimizer_path%0A%20%20%20%20)%0A%20%20%20%20return%20(train_model_optimizer_path%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Such%20optimizer%20status%20can%20be%20loaded%20as%20follows.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M%2C%20train_model_optimizer_path)%3A%0A%20%20%20%20_checkpoint%20%3D%20torch.load(train_model_optimizer_path%2C%20weights_only%3DTrue)%0A%0A%20%20%20%20_model%20%3D%20GPTModel(GPT_CONFIG_124M)%0A%20%20%20%20_model.load_state_dict(_checkpoint%5B%22model_state_dict%22%5D)%0A%0A%20%20%20%20_optimizer%20%3D%20torch.optim.AdamW(_model.parameters()%2C%20lr%3D0.0005%2C%20weight_decay%3D0.1)%0A%20%20%20%20_optimizer.load_state_dict(_checkpoint%5B%22optimizer_state_dict%22%5D)%0A%20%20%20%20_model.train()%3B%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%205.5%20Loading%20pretrained%20weights%20from%20OpenAI%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20will%20use%20trained%20GPT2%20model%20from%20OpenAI.%0A%20%20%20%20It%20is%20defined%20as%20TensorFlow%20model%2C%20so%20we%20need%20to%20convert%20it%20to%20PyTorch%20model.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20print(%22TensorFlow%20version%3A%22%2C%20version(%22tensorflow%22))%0A%20%20%20%20print(%22tqdm%20version%3A%22%2C%20version(%22tqdm%22))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20%60download_and_load_gpt2%60%20function%20is%20defined%20at%20%5Bgpt_download.py%5D(.%2Fgpt_download.py).%0A%20%20%20%20The%20code%20is%20just%20downloading%20files%20and%20read%20the%20data%20only.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(project_root)%3A%0A%20%20%20%20models_dir%20%3D%20project_root%20%2F%20%22models%22%20%2F%20%22gpt2%22%0A%20%20%20%20models_dir.mkdir(parents%3DTrue%2C%20exist_ok%3DTrue)%0A%20%20%20%20settings%2C%20params%20%3D%20download_and_load_gpt2(model_size%3D%22124M%22%2C%20models_dir%3Dmodels_dir)%0A%20%20%20%20return%20params%2C%20settings%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20These%20are%20hyperparameters.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(settings)%3A%0A%20%20%20%20print(%22Settings%3A%22%2C%20settings)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20These%20are%20keys%20of%20parameters.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(params)%3A%0A%20%20%20%20print(%22Parameter%20dictionary%20keys%3A%22%2C%20params.keys())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20access%20each%20parameters%20like%20this.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(params)%3A%0A%20%20%20%20print(params%5B%22wte%22%5D)%0A%20%20%20%20print(%22Token%20embedding%20weight%20tensor%20dimensions%3A%22%2C%20params%5B%22wte%22%5D.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20files%20supports%20several%20architecture%20of%20GPT2.%20We%20will%20use%20the%20smallest%20model.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GPT_CONFIG_124M)%3A%0A%20%20%20%20%23%20Define%20model%20configurations%20in%20a%20dictionary%20for%20compactness%0A%20%20%20%20model_configs%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22gpt2-small%20(124M)%22%3A%20%7B%22emb_dim%22%3A%20768%2C%20%22n_layers%22%3A%2012%2C%20%22n_heads%22%3A%2012%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-medium%20(355M)%22%3A%20%7B%22emb_dim%22%3A%201024%2C%20%22n_layers%22%3A%2024%2C%20%22n_heads%22%3A%2016%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-large%20(774M)%22%3A%20%7B%22emb_dim%22%3A%201280%2C%20%22n_layers%22%3A%2036%2C%20%22n_heads%22%3A%2020%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-xl%20(1558M)%22%3A%20%7B%22emb_dim%22%3A%201600%2C%20%22n_layers%22%3A%2048%2C%20%22n_heads%22%3A%2025%7D%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20%23%20Copy%20the%20base%20configuration%20and%20update%20with%20specific%20model%20settings%0A%20%20%20%20model_name%20%3D%20%22gpt2-small%20(124M)%22%20%20%23%20Example%20model%20name%0A%20%20%20%20NEW_CONFIG%20%3D%20GPT_CONFIG_124M.copy()%0A%20%20%20%20NEW_CONFIG.update(model_configs%5Bmodel_name%5D)%0A%20%20%20%20NEW_CONFIG.update(%7B%22context_length%22%3A%201024%2C%20%22qkv_bias%22%3A%20True%7D)%0A%0A%20%20%20%20gpt%20%3D%20GPTModel(NEW_CONFIG)%0A%20%20%20%20gpt.eval()%3B%0A%20%20%20%20return%20NEW_CONFIG%2C%20gpt%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20function%20to%20load%20numerics%20as%20model%20weights%20with%20sanity%20checks.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20assign(left%2C%20right)%3A%0A%20%20%20%20if%20left.shape%20!%3D%20right.shape%3A%0A%20%20%20%20%20%20%20%20raise%20ValueError(f%22Shape%20mismatch.%20Left%3A%20%7Bleft.shape%7D%2C%20Right%3A%20%7Bright.shape%7D%22)%0A%20%20%20%20return%20torch.nn.Parameter(torch.tensor(right))%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20loader%20function%20for%20GPT2%20weights%20by%20using%20the%20above.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20load_weights_into_gpt(gpt%2C%20params)%3A%0A%20%20%20%20%23%20positional%20and%20token%20embeddings%0A%20%20%20%20gpt.pos_emb.weight%20%3D%20assign(gpt.pos_emb.weight%2C%20params%5B%22wpe%22%5D)%0A%20%20%20%20gpt.tok_emb.weight%20%3D%20assign(gpt.tok_emb.weight%2C%20params%5B%22wte%22%5D)%0A%0A%20%20%20%20%23%20transformer%20blocks%0A%20%20%20%20for%20b%20in%20range(len(params%5B%22blocks%22%5D))%3A%0A%20%20%20%20%20%20%20%20%23%20multi-head%20attention%20-%3E%20linear%20projection%20(matrix)%0A%20%20%20%20%20%20%20%20q_w%2C%20k_w%2C%20v_w%20%3D%20np.split(%0A%20%20%20%20%20%20%20%20%20%20%20%20(params%5B%22blocks%22%5D%5Bb%5D%5B%22attn%22%5D%5B%22c_attn%22%5D)%5B%22w%22%5D%2C%203%2C%20axis%3D-1%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_query.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_query.weight%2C%20q_w.T%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_key.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_key.weight%2C%20k_w.T%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_value.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_value.weight%2C%20v_w.T%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20multi-head%20attention%20-%3E%20linear%20projection%20(bias)%0A%20%20%20%20%20%20%20%20q_b%2C%20k_b%2C%20v_b%20%3D%20np.split(%0A%20%20%20%20%20%20%20%20%20%20%20%20(params%5B%22blocks%22%5D%5Bb%5D%5B%22attn%22%5D%5B%22c_attn%22%5D)%5B%22b%22%5D%2C%203%2C%20axis%3D-1%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_query.bias%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_query.bias%2C%20q_b%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_key.bias%20%3D%20assign(gpt.trf_blocks%5Bb%5D.att.W_key.bias%2C%20k_b)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_value.bias%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.W_value.bias%2C%20v_b%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20multi-head%20attention%20-%3E%20output%20projection%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.out_proj.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.out_proj.weight%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22blocks%22%5D%5Bb%5D%5B%22attn%22%5D%5B%22c_proj%22%5D%5B%22w%22%5D.T%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.out_proj.bias%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.att.out_proj.bias%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22blocks%22%5D%5Bb%5D%5B%22attn%22%5D%5B%22c_proj%22%5D%5B%22b%22%5D%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20feed-forward%20network%20(Linear%20-%3E%20GELU%20-%3E%20Linear)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B0%5D.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B0%5D.weight%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22blocks%22%5D%5Bb%5D%5B%22mlp%22%5D%5B%22c_fc%22%5D%5B%22w%22%5D.T%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B0%5D.bias%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B0%5D.bias%2C%20params%5B%22blocks%22%5D%5Bb%5D%5B%22mlp%22%5D%5B%22c_fc%22%5D%5B%22b%22%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B2%5D.weight%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B2%5D.weight%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22blocks%22%5D%5Bb%5D%5B%22mlp%22%5D%5B%22c_proj%22%5D%5B%22w%22%5D.T%2C%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B2%5D.bias%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.ff.layers%5B2%5D.bias%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22blocks%22%5D%5Bb%5D%5B%22mlp%22%5D%5B%22c_proj%22%5D%5B%22b%22%5D%2C%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20%20%20%20%20%23%20layer%20normalization%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm1.scale%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm1.scale%2C%20params%5B%22blocks%22%5D%5Bb%5D%5B%22ln_1%22%5D%5B%22g%22%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm1.shift%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm1.shift%2C%20params%5B%22blocks%22%5D%5Bb%5D%5B%22ln_1%22%5D%5B%22b%22%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm2.scale%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm2.scale%2C%20params%5B%22blocks%22%5D%5Bb%5D%5B%22ln_2%22%5D%5B%22g%22%5D%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm2.shift%20%3D%20assign(%0A%20%20%20%20%20%20%20%20%20%20%20%20gpt.trf_blocks%5Bb%5D.norm2.shift%2C%20params%5B%22blocks%22%5D%5Bb%5D%5B%22ln_2%22%5D%5B%22b%22%5D%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20gpt.final_norm.scale%20%3D%20assign(gpt.final_norm.scale%2C%20params%5B%22g%22%5D)%0A%20%20%20%20gpt.final_norm.shift%20%3D%20assign(gpt.final_norm.shift%2C%20params%5B%22b%22%5D)%0A%20%20%20%20gpt.out_head.weight%20%3D%20assign(gpt.out_head.weight%2C%20params%5B%22wte%22%5D)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Then%2C%20load%20it.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(device%2C%20gpt%2C%20params)%3A%0A%20%20%20%20load_weights_into_gpt(gpt%2C%20params)%0A%20%20%20%20gpt.to(device)%3B%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20It%20outputs%20rational%20sentence.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(NEW_CONFIG%2C%20device%2C%20gpt%2C%20tokenizer)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20_token_ids%20%3D%20generate(%0A%20%20%20%20%20%20%20%20model%3Dgpt%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(%22Every%20effort%20moves%20you%22%2C%20tokenizer).to(device)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D25%2C%0A%20%20%20%20%20%20%20%20context_size%3DNEW_CONFIG%5B%22context_length%22%5D%2C%0A%20%20%20%20%20%20%20%20top_k%3D50%2C%0A%20%20%20%20%20%20%20%20temperature%3D1.5%0A%20%20%20%20)%0A%0A%20%20%20%20print(%22Output%20text%3A%5Cn%22%2C%20token_ids_to_text(_token_ids%2C%20tokenizer))%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
75d4f685fad17622ec9ff5e9141e7dc5