import%20marimo%0A%0A__generated_with%20%3D%20%220.18.4%22%0Aapp%20%3D%20marimo.App()%0A%0Awith%20app.setup%3A%0A%20%20%20%20import%20copy%0A%20%20%20%20import%20os%0A%20%20%20%20import%20time%0A%20%20%20%20import%20zipfile%0A%20%20%20%20from%20importlib.metadata%20import%20version%0A%20%20%20%20from%20pathlib%20import%20Path%0A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20matplotlib.pyplot%20as%20plt%0A%20%20%20%20import%20pandas%20as%20pd%0A%20%20%20%20import%20requests%0A%20%20%20%20import%20tiktoken%0A%20%20%20%20import%20torch%0A%20%20%20%20%23%20previous%20notebooks%0A%20%20%20%20from%20ch04%20import%20GPTModel%2C%20generate_text_simple%0A%20%20%20%20from%20ch05%20import%20(load_weights_into_gpt%2C%20text_to_token_ids%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20token_ids_to_text)%0A%20%20%20%20%23%20additional%20utility%0A%20%20%20%20from%20gpt_download%20import%20download_and_load_gpt2%0A%20%20%20%20from%20torch.utils.data%20import%20DataLoader%2C%20Dataset%0A%0A%20%20%20%20%23%20detect%20available%20device%20(CPU%20or%20GPU)%0A%20%20%20%20%23%20skip%20the%20condition%20for%20MPS%20devices%20for%20simplicity%0A%20%20%20%20device%20%3D%20torch.device(%22cuda%22%20if%20torch.cuda.is_available()%20else%20%22cpu%22)%0A%20%20%20%20print(%22Device%3A%22%2C%20device)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Show%20library%20versions%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20pkgs%20%3D%20%5B%22matplotlib%22%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%22numpy%22%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%22tiktoken%22%2C%20%0A%20%20%20%20%20%20%20%20%20%20%20%20%22torch%22%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%22tensorflow%22%20%23%20For%20OpenAI's%20pretrained%20weights%0A%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20for%20p%20in%20pkgs%3A%0A%20%20%20%20%20%20%20%20print(f%22%7Bp%7D%20version%3A%20%7Bversion(p)%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.2%20Preparing%20the%20dataset%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Set%20paths%20to%20download%20dataset%20for%20fine-tuning%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20dataset_url%20%3D%20%22https%3A%2F%2Farchive.ics.uci.edu%2Fstatic%2Fpublic%2F228%2Fsms%2Bspam%2Bcollection.zip%22%0A%20%20%20%20project_root%20%3D%20Path(__file__).parent.parent.parent%0A%20%20%20%20data_root%20%3D%20project_root%20%2F%20%22data%22%20%2F%20%22ch06%22%0A%20%20%20%20os.makedirs(data_root%2C%20exist_ok%3DTrue)%0A%0A%20%20%20%20zip_path%20%3D%20data_root%20%2F%20%22sms_spam_collection.zip%22%0A%20%20%20%20extracted_path%20%3D%20data_root%20%2F%20%22sms_spam_collection%22%0A%20%20%20%20data_file_path%20%3D%20extracted_path%20%2F%20%22SMSSpamCollection.tsv%22%0A%0A%20%20%20%20print(f%22Downloading%20dataset%20to%20%7Bzip_path%7D...%22)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20data_file_path%2C%0A%20%20%20%20%20%20%20%20data_root%2C%0A%20%20%20%20%20%20%20%20dataset_url%2C%0A%20%20%20%20%20%20%20%20extracted_path%2C%0A%20%20%20%20%20%20%20%20project_root%2C%0A%20%20%20%20%20%20%20%20zip_path%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20a%20function%20to%20download%2C%20extract%2C%20and%20rename%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20download_and_unzip_spam_data(url%2C%20zip_path%2C%20extracted_path%2C%20data_file_path)%3A%0A%20%20%20%20if%20data_file_path.exists()%3A%0A%20%20%20%20%20%20%20%20print(f%22File%20%7Bdata_file_path%7D%20already%20exists.%20Skipping%20download.%22)%0A%20%20%20%20%20%20%20%20return%0A%0A%20%20%20%20%23%20Downloading%20the%20file%0A%20%20%20%20response%20%3D%20requests.get(url%2C%20stream%3DTrue%2C%20timeout%3D60)%0A%20%20%20%20response.raise_for_status()%0A%20%20%20%20with%20open(zip_path%2C%20%22wb%22)%20as%20out_file%3A%0A%20%20%20%20%20%20%20%20for%20chunk%20in%20response.iter_content(chunk_size%3D8192)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20chunk%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20out_file.write(chunk)%0A%0A%20%20%20%20%23%20Unzipping%20the%20file%0A%20%20%20%20with%20zipfile.ZipFile(zip_path%2C%20%22r%22)%20as%20zip_ref%3A%0A%20%20%20%20%20%20%20%20zip_ref.extractall(extracted_path)%0A%20%20%20%20os.remove(zip_path)%20%23%20cleanup%20zip%20file%0A%0A%20%20%20%20%23%20Add%20.tsv%20file%20extension%0A%20%20%20%20original_file_path%20%3D%20Path(extracted_path)%20%2F%20%22SMSSpamCollection%22%0A%20%20%20%20os.rename(original_file_path%2C%20data_file_path)%0A%20%20%20%20print(f%22File%20downloaded%20and%20saved%20as%20%7Bdata_file_path%7D%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Prepare%20dataset%20by%20using%20the%20above%20function%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(data_file_path%2C%20dataset_url%2C%20extracted_path%2C%20zip_path)%3A%0A%20%20%20%20try%3A%0A%20%20%20%20%20%20%20%20download_and_unzip_spam_data(dataset_url%2C%20zip_path%2C%20extracted_path%2C%20data_file_path)%0A%20%20%20%20except%20(requests.exceptions.RequestException%2C%20TimeoutError)%20as%20e%3A%0A%20%20%20%20%20%20%20%20print(f%22Primary%20URL%20failed%3A%20%7Be%7D.%20Trying%20backup%20URL...%22)%0A%20%20%20%20%20%20%20%20_backup_url%20%3D%20%22https%3A%2F%2Ff001.backblazeb2.com%2Ffile%2FLLMs-from-scratch%2Fsms%252Bspam%252Bcollection.zip%22%0A%20%20%20%20%20%20%20%20download_and_unzip_spam_data(_backup_url%2C%20zip_path%2C%20extracted_path%2C%20data_file_path)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20dataset%20has%20Tab%20separater%20without%20header.%20Each%20rows%20are%20the%20pair%20of%20spam%2Fham%20pair%20and%20corresponding%20text%20(message).%20The%20ham%20means%20not%20spam.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(data_file_path)%3A%0A%20%20%20%20df%20%3D%20pd.read_csv(data_file_path%2C%20sep%3D%22%5Ct%22%2C%20header%3DNone%2C%20names%3D%5B%22Label%22%2C%20%22Text%22%5D)%0A%20%20%20%20df%0A%20%20%20%20return%20(df%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Show%20the%20statistics.%20Spam%20messages%20are%20anomalous%20and%20its%20number%20is%20much%20smallear%20than%20ham%20messages.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(df)%3A%0A%20%20%20%20print(df%5B%22Label%22%5D.value_counts())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Undersample%20to%20balance%20the%20number%20of%20spam%20and%20ham%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20create_balanced_dataset(df)%3A%0A%20%20%20%20%23%20Count%20the%20instances%20of%20%22spam%22%0A%20%20%20%20num_spam%20%3D%20df%5Bdf%5B%22Label%22%5D%20%3D%3D%20%22spam%22%5D.shape%5B0%5D%0A%0A%20%20%20%20%23%20Randomly%20sample%20%22ham%22%20instances%20to%20match%20the%20number%20of%20%22spam%22%20instances%0A%20%20%20%20ham_subset%20%3D%20df%5Bdf%5B%22Label%22%5D%20%3D%3D%20%22ham%22%5D.sample(num_spam%2C%20random_state%3D123)%0A%0A%20%20%20%20%23%20Combine%20ham%20%22subset%22%20with%20%22spam%22%0A%20%20%20%20balanced_df%20%3D%20pd.concat(%5Bham_subset%2C%20df%5Bdf%5B%22Label%22%5D%20%3D%3D%20%22spam%22%5D%5D)%0A%0A%20%20%20%20return%20balanced_df%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Execute%20and%20check%20the%20result%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(df)%3A%0A%20%20%20%20balanced_df%20%3D%20create_balanced_dataset(df)%0A%20%20%20%20print(balanced_df%5B%22Label%22%5D.value_counts())%0A%20%20%20%20return%20(balanced_df%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Mapping%20%60str%60%20to%20%60int%60%20for%20numerical%20operations%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(balanced_df)%3A%0A%20%20%20%20balanced_df%5B%22Label%22%5D%20%3D%20balanced_df%5B%22Label%22%5D.map(%7B%22ham%22%3A%200%2C%20%22spam%22%3A%201%7D)%20%20%20%20%0A%20%20%20%20balanced_df%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Create%20splits%20for%20training%2C%20validation%2C%20and%20testing.%0A%20%20%20%20For%20%5Bpandas.DataFrame.sample%5D(https%3A%2F%2Fpandas.pydata.org%2Fdocs%2Freference%2Fapi%2Fpandas.DataFrame.sample.html)%2C%20the%20%60frac%3D1%60%20means%20100%25%20sampling%2C%20and%20the%20%60.reset_index(drop%3DTrue)%60%20means%20the%20operation%20delete%20the%20original%20index%20(row%20number).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20random_split(df%2C%20train_frac%2C%20validation_frac)%3A%0A%20%20%20%20%22%22%22%0A%20%20%20%20Randomly%20split%20a%20DataFrame%20into%20training%2C%20validation%2C%20and%20test%20sets.%0A%0A%20%20%20%20%3Aparam%20df%3A%20The%20DataFrame%20to%20split.%0A%20%20%20%20%3Aparam%20train_frac%3A%20Fraction%20(0-1)%20of%20data%20to%20use%20for%20training.%0A%20%20%20%20%3Aparam%20validation_frac%3A%20Fraction%20(0-1)%20of%20data%20to%20use%20for%20validation.%0A%20%20%20%20%3Areturn%3A%20A%20tuple%20of%20(train_df%2C%20validation_df%2C%20test_df).%0A%20%20%20%20%22%22%22%0A%20%20%20%20%23%20Shuffle%20the%20entire%20DataFrame%0A%20%20%20%20df%20%3D%20df.sample(frac%3D1%2C%20random_state%3D123).reset_index(drop%3DTrue)%0A%0A%20%20%20%20%23%20Calculate%20split%20indices%0A%20%20%20%20train_end%20%3D%20int(len(df)%20*%20train_frac)%0A%20%20%20%20validation_end%20%3D%20train_end%20%2B%20int(len(df)%20*%20validation_frac)%0A%0A%20%20%20%20%23%20Split%20the%20DataFrame%0A%20%20%20%20train_df%20%3D%20df%5B%3Atrain_end%5D%0A%20%20%20%20validation_df%20%3D%20df%5Btrain_end%3Avalidation_end%5D%0A%20%20%20%20test_df%20%3D%20df%5Bvalidation_end%3A%5D%0A%0A%20%20%20%20return%20train_df%2C%20validation_df%2C%20test_df%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Create%20the%20splits%20with%20the%20ratio%20for%0A%0A%20%20%20%20%24%24%0A%20%20%20%20%5Ctext%7Btrain%7D%20%3A%20%5Ctext%7Bvalid%7D%20%3A%20%5Ctext%7Btest%7D%20%3D%200.7%20%3A%200.1%20%3A%200.2%0A%20%20%20%20%24%24%0A%0A%20%20%20%20and%20save%20these%20dataframes%20as%20CSV%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(balanced_df%2C%20data_root)%3A%0A%20%20%20%20train_df%2C%20validation_df%2C%20test_df%20%3D%20random_split(balanced_df%2C%200.7%2C%200.1)%0A%20%20%20%20%23%20Test%20size%20is%20implied%20to%20be%200.2%20as%20the%20remainder%0A%0A%20%20%20%20train_path%20%3D%20data_root%20%2F%20%22train.csv%22%0A%20%20%20%20valid_path%20%3D%20data_root%20%2F%20%22validation.csv%22%0A%20%20%20%20test_path%20%3D%20data_root%20%2F%20%22test.csv%22%0A%20%20%20%20train_df.to_csv(train_path%2C%20index%3DNone)%0A%20%20%20%20validation_df.to_csv(valid_path%2C%20index%3DNone)%0A%20%20%20%20test_df.to_csv(test_path%2C%20index%3DNone)%0A%20%20%20%20return%20test_path%2C%20train_path%2C%20valid_path%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.3%20Creating%20data%20loaders%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Get%20tiktoken%20encoder%20and%20the%20token%20ID%20for%20%60%3C%7Cendoftext%7C%3E%60%20to%20use%20it%20as%20padding%20token%20ID%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20tokenizer%20%3D%20tiktoken.get_encoding(%22gpt2%22)%0A%20%20%20%20print(tokenizer.encode(%22%3C%7Cendoftext%7C%3E%22%2C%20allowed_special%3D%7B%22%3C%7Cendoftext%7C%3E%22%7D))%0A%20%20%20%20return%20(tokenizer%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20%5BMap-style%20datasets%5D(https%3A%2F%2Fdocs.pytorch.org%2Fdocs%2Fstable%2Fdata.html%23map-style-datasets)%20for%20use%20with%20a%20DataLoader.%0A%20%20%20%20Token%20sequences%20are%20padded%20to%20a%20uniform%20length%2C%20either%20specified%20explicitly%20or%20set%20to%20the%20maximum%20sequence%20length%20in%20the%20dataset.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20SpamDataset(Dataset)%3A%0A%20%20%20%20def%20__init__(self%2C%20csv_file%2C%20tokenizer%2C%20max_length%3DNone%2C%20pad_token_id%3D50256)%3A%0A%20%20%20%20%20%20%20%20self.data%20%3D%20pd.read_csv(csv_file)%0A%0A%20%20%20%20%20%20%20%20%23%20Pre-tokenize%20texts%0A%20%20%20%20%20%20%20%20self.encoded_texts%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20tokenizer.encode(text)%20for%20text%20in%20self.data%5B%22Text%22%5D%0A%20%20%20%20%20%20%20%20%5D%0A%0A%20%20%20%20%20%20%20%20%23%20set%20sequence%20length%20to%20pad%0A%20%20%20%20%20%20%20%20if%20max_length%20is%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.max_length%20%3D%20self._longest_encoded_length()%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.max_length%20%3D%20max_length%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Truncate%20sequences%20if%20they%20are%20longer%20than%20max_length%0A%20%20%20%20%20%20%20%20%20%20%20%20self.encoded_texts%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20encoded_text%5B%3Aself.max_length%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20encoded_text%20in%20self.encoded_texts%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%0A%20%20%20%20%20%20%20%20%23%20Pad%20sequences%20to%20the%20longest%20sequence%0A%20%20%20%20%20%20%20%20self.encoded_texts%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20encoded_text%20%2B%20%5Bpad_token_id%5D%20*%20(self.max_length%20-%20len(encoded_text))%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20encoded_text%20in%20self.encoded_texts%0A%20%20%20%20%20%20%20%20%5D%0A%0A%20%20%20%20def%20__getitem__(self%2C%20index)%3A%0A%20%20%20%20%20%20%20%20encoded%20%3D%20self.encoded_texts%5Bindex%5D%0A%20%20%20%20%20%20%20%20label%20%3D%20self.data.iloc%5Bindex%5D%5B%22Label%22%5D%0A%20%20%20%20%20%20%20%20%23%20torch.long%20%3D%3D%20torch.int64%0A%20%20%20%20%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20torch.tensor(encoded%2C%20dtype%3Dtorch.long)%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20torch.tensor(label%2C%20dtype%3Dtorch.long)%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20__len__(self)%3A%0A%20%20%20%20%20%20%20%20return%20len(self.data)%0A%0A%20%20%20%20def%20_longest_encoded_length(self)%3A%0A%20%20%20%20%20%20%20%20max_length%20%3D%200%0A%20%20%20%20%20%20%20%20for%20encoded_text%20in%20self.encoded_texts%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20encoded_length%20%3D%20len(encoded_text)%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20encoded_length%20%3E%20max_length%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20max_length%20%3D%20encoded_length%0A%20%20%20%20%20%20%20%20return%20max_length%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Construct%20a%20dataset%20for%20training.%20We%20can%20set%20the%20%60max_length%60%20to%20our%20context%20length%20to%20ensure%20it.%0A%20%20%20%20After%20the%20creation%2C%20check%20the%20%60max_length%60%20and%20the%20shape%20of%20each%20data.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(tokenizer%2C%20train_path)%3A%0A%20%20%20%20train_dataset%20%3D%20SpamDataset(%0A%20%20%20%20%20%20%20%20csv_file%3Dtrain_path%2C%0A%20%20%20%20%20%20%20%20max_length%3DNone%2C%20%20%23%20use%20the%20max%20length%20of%20actual%20sequences%0A%20%20%20%20%20%20%20%20tokenizer%3Dtokenizer%0A%20%20%20%20)%0A%0A%20%20%20%20print(f%22%7Btrain_dataset.max_length%3D%7D%22)%0A%20%20%20%20%23%20a%20tuple%20of%20seqeuence%20and%20its%20label%0A%20%20%20%20print(f%22%7Btrain_dataset%5B0%5D%3D%7D%22)%20%20%0A%20%20%20%20%23%20the%20length%20of%20sequence%20is%20identical%20with%20the%20max_length%0A%20%20%20%20print(f%22%7Btrain_dataset%5B0%5D%5B0%5D.shape%3D%7D%22)%0A%20%20%20%20return%20(train_dataset%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Create%20validation%20and%20test%20datasets%20also.%20We%20share%20the%20%60max_length%60%20for%20each%20construction%20to%20ensure%20the%20same%20condition.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(test_path%2C%20tokenizer%2C%20train_dataset%2C%20valid_path)%3A%0A%20%20%20%20val_dataset%20%3D%20SpamDataset(%0A%20%20%20%20%20%20%20%20csv_file%3Dvalid_path%2C%0A%20%20%20%20%20%20%20%20max_length%3Dtrain_dataset.max_length%2C%0A%20%20%20%20%20%20%20%20tokenizer%3Dtokenizer%0A%20%20%20%20)%0A%20%20%20%20test_dataset%20%3D%20SpamDataset(%0A%20%20%20%20%20%20%20%20csv_file%3Dtest_path%2C%0A%20%20%20%20%20%20%20%20max_length%3Dtrain_dataset.max_length%2C%0A%20%20%20%20%20%20%20%20tokenizer%3Dtokenizer%0A%20%20%20%20)%0A%20%20%20%20return%20test_dataset%2C%20val_dataset%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Construct%20%5Bdaloaders%5D(https%3A%2F%2Fdocs.pytorch.org%2Fdocs%2Fstable%2Fdata.html%23torch.utils.data.DataLoader)%20by%20using%20the%20datasets%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(test_dataset%2C%20train_dataset%2C%20val_dataset)%3A%0A%20%20%20%20num_workers%20%3D%200%20%20%23%20ensure%20compatibility%20(for%20all%20CPUs)%0A%20%20%20%20batch_size%20%3D%208%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20train_loader%20%3D%20DataLoader(%0A%20%20%20%20%20%20%20%20dataset%3Dtrain_dataset%2C%0A%20%20%20%20%20%20%20%20batch_size%3Dbatch_size%2C%0A%20%20%20%20%20%20%20%20shuffle%3DTrue%2C%20%20%23%20randomize%0A%20%20%20%20%20%20%20%20num_workers%3Dnum_workers%2C%0A%20%20%20%20%20%20%20%20drop_last%3DTrue%2C%20%20%23%20ignore%20incomplete%20batch%0A%20%20%20%20)%0A%0A%20%20%20%20val_loader%20%3D%20DataLoader(%0A%20%20%20%20%20%20%20%20dataset%3Dval_dataset%2C%0A%20%20%20%20%20%20%20%20batch_size%3Dbatch_size%2C%0A%20%20%20%20%20%20%20%20num_workers%3Dnum_workers%2C%0A%20%20%20%20%20%20%20%20drop_last%3DFalse%2C%0A%20%20%20%20)%0A%0A%20%20%20%20test_loader%20%3D%20DataLoader(%0A%20%20%20%20%20%20%20%20dataset%3Dtest_dataset%2C%0A%20%20%20%20%20%20%20%20batch_size%3Dbatch_size%2C%0A%20%20%20%20%20%20%20%20num_workers%3Dnum_workers%2C%0A%20%20%20%20%20%20%20%20drop_last%3DFalse%2C%0A%20%20%20%20)%0A%20%20%20%20return%20test_loader%2C%20train_loader%2C%20val_loader%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Iterate%20entirely%20for%20training%20dataloader%20to%20get%20the%20final%20batch%20and%20check%20the%20shape.%20We%20can%20see%20surely%20the%20final%20batch%20is%20complete%20shape.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_loader)%3A%0A%20%20%20%20%23%20iteration%20for%20doing%20nothing%0A%20%20%20%20for%20input_batch%2C%20target_batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20pass%0A%0A%20%20%20%20%23%20we%20expect%208%20pairs%20in%20a%20batch%20even%20it%20is%20the%20final%20batch%0A%20%20%20%20print(%22Input%20batch%20dimensions%3A%22%2C%20input_batch.shape)%0A%20%20%20%20print(%22Label%20batch%20dimensions%22%2C%20target_batch.shape)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Check%20the%20number%20of%20batches%20for%20each%20dataloaders%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(test_loader%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20_num_train%20%3D%20len(train_loader.dataset)%0A%20%20%20%20_num_val%20%3D%20len(val_loader.dataset)%0A%20%20%20%20_num_test%20%3D%20len(test_loader.dataset)%0A%20%20%20%20print(f%22%7B_num_train%7D%20training%20batches%22)%0A%20%20%20%20print(f%22%7B_num_val%7D%20validation%20batches%22)%0A%20%20%20%20print(f%22%7B_num_test%7D%20test%20batches%22)%0A%0A%20%20%20%20%23%20check%20the%20ratio%20is%20as%20expected%20or%20not%0A%20%20%20%20_num_all%20%3D%20_num_train%20%2B%20_num_val%20%2B%20_num_test%0A%20%20%20%20_train_ratio%20%3D%20_num_train%20%2F%20_num_all%0A%20%20%20%20_val_ratio%20%3D%20_num_val%20%2F%20_num_all%0A%20%20%20%20_test_ratio%20%3D%20_num_test%20%2F%20_num_all%0A%20%20%20%20print(f%22Train%3AVal%3ATest%20%3D%20%7B_train_ratio%3A.2f%7D%3A%7B_val_ratio%3A.2f%7D%3A%7B_test_ratio%3A.2f%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.4%20Initializing%20a%20model%20with%20pretrained%20weights%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Use%20the%20same%20configuration%20of%20the%20model%20with%20the%20pretraining%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(train_dataset)%3A%0A%20%20%20%20CHOOSE_MODEL%20%3D%20%22gpt2-small%20(124M)%22%0A%20%20%20%20INPUT_PROMPT%20%3D%20%22Every%20effort%20moves%22%0A%0A%20%20%20%20BASE_CONFIG%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22vocab_size%22%3A%2050257%2C%20%20%20%20%20%23%20Vocabulary%20size%0A%20%20%20%20%20%20%20%20%22context_length%22%3A%201024%2C%20%20%23%20Context%20length%0A%20%20%20%20%20%20%20%20%22drop_rate%22%3A%200.0%2C%20%20%20%20%20%20%20%20%23%20Dropout%20rate%0A%20%20%20%20%20%20%20%20%22qkv_bias%22%3A%20True%20%20%20%20%20%20%20%20%20%23%20Query-key-value%20bias%0A%20%20%20%20%7D%0A%0A%20%20%20%20model_configs%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22gpt2-small%20(124M)%22%3A%20%7B%22emb_dim%22%3A%20768%2C%20%22n_layers%22%3A%2012%2C%20%22n_heads%22%3A%2012%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-medium%20(355M)%22%3A%20%7B%22emb_dim%22%3A%201024%2C%20%22n_layers%22%3A%2024%2C%20%22n_heads%22%3A%2016%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-large%20(774M)%22%3A%20%7B%22emb_dim%22%3A%201280%2C%20%22n_layers%22%3A%2036%2C%20%22n_heads%22%3A%2020%7D%2C%0A%20%20%20%20%20%20%20%20%22gpt2-xl%20(1558M)%22%3A%20%7B%22emb_dim%22%3A%201600%2C%20%22n_layers%22%3A%2048%2C%20%22n_heads%22%3A%2025%7D%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20BASE_CONFIG.update(model_configs%5BCHOOSE_MODEL%5D)%0A%0A%20%20%20%20assert%20train_dataset.max_length%20%3C%3D%20BASE_CONFIG%5B%22context_length%22%5D%2C%20(%0A%20%20%20%20%20%20%20%20f%22Dataset%20length%20%7Btrain_dataset.max_length%7D%20exceeds%20model's%20context%20%22%0A%20%20%20%20%20%20%20%20f%22length%20%7BBASE_CONFIG%5B'context_length'%5D%7D.%20Reinitialize%20data%20sets%20with%20%22%0A%20%20%20%20%20%20%20%20f%22%60max_length%3D%7BBASE_CONFIG%5B'context_length'%5D%7D%60%22%0A%20%20%20%20)%0A%20%20%20%20return%20BASE_CONFIG%2C%20CHOOSE_MODEL%0A%0A%0A%40app.cell%0Adef%20_(BASE_CONFIG%2C%20CHOOSE_MODEL%2C%20project_root)%3A%0A%20%20%20%20%23%20just%20parse%20the%20%60str%60%20inside%20the%20brace%0A%20%20%20%20model_size%20%3D%20CHOOSE_MODEL.split(%22%20%22)%5B-1%5D.lstrip(%22(%22).rstrip(%22)%22)%0A%0A%20%20%20%20models_dir%20%3D%20project_root%20%2F%20%22models%22%20%2F%20%22gpt2%22%0A%20%20%20%20os.makedirs(models_dir%2C%20exist_ok%3DTrue)%0A%0A%20%20%20%20settings%2C%20params%20%3D%20download_and_load_gpt2(model_size%3Dmodel_size%2C%20models_dir%3Dmodels_dir)%0A%0A%20%20%20%20pretrained_model%20%3D%20GPTModel(BASE_CONFIG)%0A%20%20%20%20load_weights_into_gpt(pretrained_model%2C%20params)%0A%0A%20%20%20%20%23%20This%20will%20be%20finetuned%0A%20%20%20%20model%20%3D%20copy.deepcopy(pretrained_model)%0A%20%20%20%20return%20model%2C%20pretrained_model%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20model.eval()%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Check%20the%20model%20is%20pretrained%20and%20its%20inference%20is%20proper%20by%20using%20a%20simple%20sequence%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(BASE_CONFIG%2C%20pretrained_model%2C%20tokenizer)%3A%0A%20%20%20%20_text_1%20%3D%20%22Every%20effort%20moves%20you%22%0A%0A%20%20%20%20_token_ids%20%3D%20generate_text_simple(%0A%20%20%20%20%20%20%20%20model%3Dpretrained_model%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(_text_1%2C%20tokenizer)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D15%2C%0A%20%20%20%20%20%20%20%20context_size%3DBASE_CONFIG%5B%22context_length%22%5D%0A%20%20%20%20)%0A%20%20%20%20print(f%22%7B_token_ids%3D%7D%22)%0A%0A%20%20%20%20print(%22token_ids_to_text(_token_ids%2C%20tokenizer)%3D%22)%0A%20%20%20%20print(f%22%7Btoken_ids_to_text(_token_ids%2C%20tokenizer)%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20See%20the%20behavior%20before%20the%20fine-tuning.%20The%20answer%20does%20not%20follow%20the%20initial%20instructions%2C%20and%20answer%20the%20question%20without%20yes%20or%20no.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(BASE_CONFIG%2C%20pretrained_model%2C%20tokenizer)%3A%0A%20%20%20%20_text_2%20%3D%20(%0A%20%20%20%20%20%20%20%20%22Is%20the%20following%20text%20'spam'%3F%20Answer%20with%20'yes'%20or%20'no'%3A%22%0A%20%20%20%20%20%20%20%20%22%20'You%20are%20a%20winner%20you%20have%20been%20specially%22%0A%20%20%20%20%20%20%20%20%22%20selected%20to%20receive%20%241000%20cash%20or%20a%20%242000%20award.'%22%0A%20%20%20%20)%0A%0A%20%20%20%20_token_ids%20%3D%20generate_text_simple(%0A%20%20%20%20%20%20%20%20model%3Dpretrained_model%2C%0A%20%20%20%20%20%20%20%20idx%3Dtext_to_token_ids(_text_2%2C%20tokenizer)%2C%0A%20%20%20%20%20%20%20%20max_new_tokens%3D23%2C%0A%20%20%20%20%20%20%20%20context_size%3DBASE_CONFIG%5B%22context_length%22%5D%0A%20%20%20%20)%0A%0A%20%20%20%20print(token_ids_to_text(_token_ids%2C%20tokenizer))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.5%20Adding%20a%20classification%20head%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Freeze%20all%20layers%20initially%20to%20reuse%20the%20most%20of%20weights%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20for%20_param%20in%20model.parameters()%3A%0A%20%20%20%20%20%20%20%20_param.requires_grad%20%3D%20False%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20new%20linear%20layer%20to%20classify%20spam%20or%20ham%2C%20and%20replace%20the%20current%20last%20layer%20to%20output%20expected%20tokens%20with%20the%20new%20linear%20layer%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(BASE_CONFIG%2C%20model)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20num_classes%20%3D%202%0A%20%20%20%20model.out_head%20%3D%20torch.nn.Linear(in_features%3DBASE_CONFIG%5B%22emb_dim%22%5D%2C%20out_features%3Dnum_classes)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Allow%20gradient%20calculations%20for%20the%20transformer%20blocks%20and%20the%20last%20%60LayerNorm%60%20module%20to%20train%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model)%3A%0A%20%20%20%20for%20_param%20in%20model.trf_blocks%5B-1%5D.parameters()%3A%0A%20%20%20%20%20%20%20%20_param.requires_grad%20%3D%20True%0A%0A%20%20%20%20for%20_param%20in%20model.final_norm.parameters()%3A%0A%20%20%20%20%20%20%20%20_param.requires_grad%20%3D%20True%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Try%20inference%20with%20this%20architecture%20before%20fine-tuning%20by%20using%20following%20inputs%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(tokenizer)%3A%0A%20%20%20%20inputs%20%3D%20tokenizer.encode(%22Do%20you%20have%20time%22)%0A%20%20%20%20inputs%20%3D%20torch.tensor(inputs).unsqueeze(0)%0A%20%20%20%20print(%22Inputs%3A%22%2C%20inputs)%0A%20%20%20%20print(%22Inputs%20dimensions%3A%22%2C%20inputs.shape)%20%23%20shape%3A%20(batch_size%2C%20num_tokens)%0A%20%20%20%20return%20(inputs%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Inference%20with%20the%20input%20and%20check%20the%20output%20shape%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20output%20size%20is%20%24(%5Ctext%7Bbatch%20size%7D%2C%20%5Ctext%7Bnum%20tokens%7D%2C%20%5Ctext%7Bnum%20classes%7D)%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs%2C%20model)%3A%0A%20%20%20%20model.to(%22cpu%22)%20%20%23%20ensure%20device%20matching%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20outputs%20%3D%20model(inputs)%0A%0A%20%20%20%20print(%22Outputs%3A%5Cn%22%2C%20outputs)%0A%20%20%20%20print(%22Outputs%20dimensions%3A%22%2C%20outputs.shape)%20%23%20shape%3A%20(batch_size%2C%20num_tokens%2C%20num_classes)%0A%20%20%20%20return%20(outputs%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20use%20the%20classification%20result%20for%20the%20last%20token%20only%20because%20this%20is%20the%20only%20token%20can%20have%20causal%20correlation%20with%20all%20tokens%20based%20on%20causal%20attention%20mask.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(outputs)%3A%0A%20%20%20%20print(%22Last%20output%20token%3A%22%2C%20outputs%5B%3A%2C%20-1%2C%20%3A%5D)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.6%20Calculating%20the%20classification%20loss%20and%20accuracy%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Map%20the%20output%20to%20class%20label%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(outputs)%3A%0A%20%20%20%20_probas%20%3D%20torch.softmax(outputs%5B%3A%2C%20-1%2C%20%3A%5D%2C%20dim%3D-1)%0A%20%20%20%20_label%20%3D%20torch.argmax(_probas)%0A%20%20%20%20print(%22Class%20label%3A%22%2C%20_label.item())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20get%20the%20same%20result%20by%20just%20taking%20%60argmax%60%20only%20because%20%60softmax%60%20is%20monotonic%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(outputs)%3A%0A%20%20%20%20_logits%20%3D%20outputs%5B%3A%2C%20-1%2C%20%3A%5D%0A%20%20%20%20_label%20%3D%20torch.argmax(_logits)%0A%20%20%20%20print(%22Class%20label%3A%22%2C%20_label.item())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20the%20accuracy%20metric%20to%20evaluate%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20calc_accuracy_loader(data_loader%2C%20model%2C%20device%2C%20num_batches%3DNone)%3A%0A%20%20%20%20model.eval()%0A%20%20%20%20correct_predictions%2C%20num_examples%20%3D%200%2C%200%0A%0A%20%20%20%20if%20num_batches%20is%20None%3A%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20len(data_loader)%0A%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20min(num_batches%2C%20len(data_loader))%0A%0A%20%20%20%20for%20i%2C%20(input_batch%2C%20target_batch)%20in%20enumerate(data_loader)%3A%0A%20%20%20%20%20%20%20%20if%20i%20%3C%20num_batches%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20input_batch%2C%20target_batch%20%3D%20input_batch.to(device)%2C%20target_batch.to(device)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20Logits%20of%20last%20output%20token%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%23%20(num_samples%2C%20num_tokens%2C%20num_logits)%20-%3E%20%20(num_samples%2C%20num_logits)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20logits%20%3D%20model(input_batch)%5B%3A%2C%20-1%2C%20%3A%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20predicted_labels%20%3D%20torch.argmax(logits%2C%20dim%3D-1)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20num_examples%20%2B%3D%20predicted_labels.shape%5B0%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20correct_predictions%20%2B%3D%20(predicted_labels%20%3D%3D%20target_batch).sum().item()%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%0A%20%20%20%20return%20correct_predictions%20%2F%20num_examples%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Compute%20initial%20accuracies%20without%20training%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20test_loader%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20%23%20no%20assignment%20model%20%3D%20model.to(device)%20necessary%20for%20nn.Module%20classes%0A%20%20%20%20model.to(device)%0A%20%20%20%20%23%20For%20reproducibility%20due%20to%20the%20shuffling%20in%20the%20training%20data%20loader%0A%20%20%20%20torch.manual_seed(123)%20%0A%0A%20%20%20%20_train_accuracy%20%3D%20calc_accuracy_loader(train_loader%2C%20model%2C%20device%2C%20num_batches%3D10)%0A%20%20%20%20_val_accuracy%20%3D%20calc_accuracy_loader(val_loader%2C%20model%2C%20device%2C%20num_batches%3D10)%0A%20%20%20%20_test_accuracy%20%3D%20calc_accuracy_loader(test_loader%2C%20model%2C%20device%2C%20num_batches%3D10)%0A%0A%20%20%20%20print(f%22Training%20accuracy%3A%20%7B_train_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20print(f%22Validation%20accuracy%3A%20%7B_val_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20print(f%22Test%20accuracy%3A%20%7B_test_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Use%20cross%20entropy%20loss%20to%20train%20instead%20of%20the%20accuracy%20because%20of%20these%20differentiability.%20See%20%5Bthis%5D(https%3A%2F%2Fdocs.pytorch.org%2Fdocs%2Fstable%2Fgenerated%2Ftorch.nn.CrossEntropyLoss.html%23torch.nn.CrossEntropyLoss)%20for%20the%20definition.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%3A%0A%20%20%20%20input_batch%2C%20target_batch%20%3D%20input_batch.to(device)%2C%20target_batch.to(device)%0A%20%20%20%20logits%20%3D%20model(input_batch)%5B%3A%2C%20-1%2C%20%3A%5D%20%20%23%20Logits%20of%20last%20output%20token%0A%20%20%20%20loss%20%3D%20torch.nn.functional.cross_entropy(logits%2C%20target_batch)%0A%20%20%20%20return%20loss%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Compute%20average%20loss%20by%20using%20each%20dataloader%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0A%23%20Same%20as%20in%20chapter%205%0Adef%20calc_loss_loader(data_loader%2C%20model%2C%20device%2C%20num_batches%3DNone)%3A%0A%20%20%20%20if%20len(data_loader)%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20return%20float(%22nan%22)%0A%20%20%20%20elif%20num_batches%20is%20None%3A%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20len(data_loader)%0A%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%23%20Reduce%20the%20number%20of%20batches%20to%20match%20the%20total%20number%20of%20batches%20in%20the%20data%20loader%0A%20%20%20%20%20%20%20%20%23%20if%20num_batches%20exceeds%20the%20number%20of%20batches%20in%20the%20data%20loader%0A%20%20%20%20%20%20%20%20num_batches%20%3D%20min(num_batches%2C%20len(data_loader))%0A%0A%20%20%20%20total_loss%20%3D%200.%0A%20%20%20%20for%20i%2C%20(input_batch%2C%20target_batch)%20in%20enumerate(data_loader)%3A%0A%20%20%20%20%20%20%20%20if%20i%20%3C%20num_batches%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%0A%20%20%20%20%20%20%20%20%20%20%20%20total_loss%20%2B%3D%20loss.item()%0A%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20return%20total_loss%20%2F%20num_batches%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Compute%20initial%20losses%20without%20training%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20test_loader%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20model.to(device)%0A%20%20%20%20with%20torch.no_grad()%3A%20%23%20Disable%20gradient%20tracking%20for%20efficiency%20because%20we%20are%20not%20training%2C%20yet%0A%20%20%20%20%20%20%20%20_train_loss%20%3D%20calc_loss_loader(train_loader%2C%20model%2C%20device%2C%20num_batches%3D5)%0A%20%20%20%20%20%20%20%20_val_loss%20%3D%20calc_loss_loader(val_loader%2C%20model%2C%20device%2C%20num_batches%3D5)%0A%20%20%20%20%20%20%20%20_test_loss%20%3D%20calc_loss_loader(test_loader%2C%20model%2C%20device%2C%20num_batches%3D5)%0A%0A%20%20%20%20print(f%22Training%20loss%3A%20%7B_train_loss%3A.3f%7D%22)%0A%20%20%20%20print(f%22Validation%20loss%3A%20%7B_val_loss%3A.3f%7D%22)%0A%20%20%20%20print(f%22Test%20loss%3A%20%7B_test_loss%3A.3f%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.7%20Finetuning%20the%20model%20on%20supervised%20data%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20fine-tuning%20trainer%20that%20calculates%20accuracies%20in%20the%20end%20of%20each%20epochs%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0A%23%20Overall%20the%20same%20as%20%60train_model_simple%60%20in%20chapter%205%0Adef%20train_classifier_simple(model%2C%20train_loader%2C%20val_loader%2C%20optimizer%2C%20device%2C%20num_epochs%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20eval_freq%2C%20eval_iter)%3A%0A%20%20%20%20%22%22%22%0A%20%20%20%20Train%20the%20model%20using%20the%20provided%20data%20loaders%20and%20optimizer.%0A%0A%20%20%20%20%3Aparam%20model%3A%20The%20model%20to%20train.%0A%20%20%20%20%3Aparam%20train_loader%3A%20DataLoader%20for%20training%20data.%0A%20%20%20%20%3Aparam%20val_loader%3A%20DataLoader%20for%20validation%20data.%0A%20%20%20%20%3Aparam%20optimizer%3A%20Optimizer%20for%20updating%20model%20weights.%0A%20%20%20%20%3Aparam%20device%3A%20Device%20to%20run%20the%20training%20on%20(e.g.%2C%20'cpu'%20or%20'cuda').%0A%20%20%20%20%3Aparam%20num_epochs%3A%20Number%20of%20epochs%20to%20train.%0A%20%20%20%20%3Aparam%20eval_freq%3A%20Frequency%20(in%20steps)%20to%20evaluate%20the%20model.%0A%20%20%20%20%3Aparam%20eval_iter%3A%20Number%20of%20batches%20to%20use%20for%20evaluation.%0A%20%20%20%20%3Areturn%3A%20Tuple%20of%20lists%20containing%20training%20losses%2C%20validation%20losses%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20training%20accuracies%2C%20validation%20accuracies%2C%20and%20total%20examples%20seen.%0A%20%20%20%20%22%22%22%0A%20%20%20%20%23%20Initialize%20lists%20to%20track%20losses%20and%20examples%20seen%0A%20%20%20%20train_losses%2C%20val_losses%2C%20train_accs%2C%20val_accs%20%3D%20%5B%5D%2C%20%5B%5D%2C%20%5B%5D%2C%20%5B%5D%0A%20%20%20%20examples_seen%2C%20global_step%20%3D%200%2C%20-1%0A%0A%20%20%20%20%23%20Main%20training%20loop%0A%20%20%20%20for%20epoch%20in%20range(num_epochs)%3A%0A%20%20%20%20%20%20%20%20model.train()%20%20%23%20Set%20model%20to%20training%20mode%0A%0A%20%20%20%20%20%20%20%20for%20input_batch%2C%20target_batch%20in%20train_loader%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.zero_grad()%20%23%20Reset%20loss%20gradients%20from%20previous%20batch%20iteration%0A%20%20%20%20%20%20%20%20%20%20%20%20loss%20%3D%20calc_loss_batch(input_batch%2C%20target_batch%2C%20model%2C%20device)%0A%20%20%20%20%20%20%20%20%20%20%20%20loss.backward()%20%23%20Calculate%20loss%20gradients%0A%20%20%20%20%20%20%20%20%20%20%20%20optimizer.step()%20%23%20Update%20model%20weights%20using%20loss%20gradients%0A%20%20%20%20%20%20%20%20%20%20%20%20examples_seen%20%2B%3D%20input_batch.shape%5B0%5D%20%23%20New%3A%20track%20examples%20instead%20of%20tokens%0A%20%20%20%20%20%20%20%20%20%20%20%20global_step%20%2B%3D%201%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Optional%20evaluation%20step%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20global_step%20%25%20eval_freq%20%3D%3D%200%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train_loss%2C%20val_loss%20%3D%20evaluate_model(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20model%2C%20train_loader%2C%20val_loader%2C%20device%2C%20eval_iter)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20train_losses.append(train_loss)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20val_losses.append(val_loss)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Ep%20%7Bepoch%2B1%7D%20(Step%20%7Bglobal_step%3A06d%7D)%3A%20%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20f%22Train%20loss%20%7Btrain_loss%3A.3f%7D%2C%20Val%20loss%20%7Bval_loss%3A.3f%7D%22)%0A%0A%20%20%20%20%20%20%20%20%23%20Calculate%20accuracy%20after%20each%20epoch%0A%20%20%20%20%20%20%20%20train_accuracy%20%3D%20calc_accuracy_loader(train_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%20%20%20%20%20%20%20%20val_accuracy%20%3D%20calc_accuracy_loader(val_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%20%20%20%20%20%20%20%20print(f%22Training%20accuracy%3A%20%7Btrain_accuracy*100%3A.2f%7D%25%20%7C%20%22%2C%20end%3D%22%22)%0A%20%20%20%20%20%20%20%20print(f%22Validation%20accuracy%3A%20%7Bval_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20%20%20%20%20train_accs.append(train_accuracy)%0A%20%20%20%20%20%20%20%20val_accs.append(val_accuracy)%0A%0A%20%20%20%20return%20train_losses%2C%20val_losses%2C%20train_accs%2C%20val_accs%2C%20examples_seen%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20fine-tuning%20evaluater%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0A%23%20Same%20as%20chapter%205%0Adef%20evaluate_model(model%2C%20train_loader%2C%20val_loader%2C%20device%2C%20eval_iter)%3A%0A%20%20%20%20model.eval()%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20train_loss%20%3D%20calc_loss_loader(train_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%20%20%20%20%20%20%20%20val_loss%20%3D%20calc_loss_loader(val_loader%2C%20model%2C%20device%2C%20num_batches%3Deval_iter)%0A%0A%20%20%20%20model.train()%0A%20%20%20%20return%20train_loss%2C%20val_loss%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Execute%20the%20training%20in%205%20epochs%2C%20and%20the%20accuracy%20will%20be%20improved%20more%20than%2090%25.%20It%20takes%20about%203%20minuites%20for%20RTX3070.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20_start_time%20%3D%20time.time()%0A%0A%20%20%20%20model.to(device)%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20_optimizer%20%3D%20torch.optim.AdamW(model.parameters()%2C%20lr%3D5e-5%2C%20weight_decay%3D0.1)%0A%20%20%20%20num_epochs%20%3D%205%0A%0A%20%20%20%20train_losses%2C%20val_losses%2C%20train_accs%2C%20val_accs%2C%20examples_seen%20%3D%20train_classifier_simple(%0A%20%20%20%20%20%20%20%20model%2C%20train_loader%2C%20val_loader%2C%20_optimizer%2C%20device%2C%0A%20%20%20%20%20%20%20%20num_epochs%3Dnum_epochs%2C%20eval_freq%3D50%2C%20eval_iter%3D5%2C%0A%20%20%20%20)%0A%0A%20%20%20%20end_time%20%3D%20time.time()%0A%20%20%20%20execution_time_minutes%20%3D%20(end_time%20-%20_start_time)%20%2F%2060%0A%20%20%20%20print(f%22Training%20completed%20in%20%7Bexecution_time_minutes%3A.2f%7D%20minutes.%22)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20examples_seen%2C%0A%20%20%20%20%20%20%20%20num_epochs%2C%0A%20%20%20%20%20%20%20%20train_accs%2C%0A%20%20%20%20%20%20%20%20train_losses%2C%0A%20%20%20%20%20%20%20%20val_accs%2C%0A%20%20%20%20%20%20%20%20val_losses%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Plot%20the%20losses%20and%20verify%20it%20works%20well%20from%20the%20fact%20that%20losses%20decrease%20rapidlly.%20Especially%2C%20validation%20losses%20looks%20like%20training%20losses%20and%20it%20shows%20the%20traning%20avoids%20overfitting.%20These%20facts%20justify%205%20epochs%20are%20enough%20for%20this%20fine-tuning.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(data_root)%3A%0A%20%20%20%20def%20plot_values(epochs_seen%2C%20examples_seen%2C%20train_values%2C%20val_values%2C%20label%3D%22loss%22)%3A%0A%20%20%20%20%20%20%20%20fig%2C%20ax1%20%3D%20plt.subplots(figsize%3D(5%2C%203))%0A%0A%20%20%20%20%20%20%20%20%23%20Plot%20training%20and%20validation%20loss%20against%20epochs%0A%20%20%20%20%20%20%20%20ax1.plot(epochs_seen%2C%20train_values%2C%20label%3Df%22Training%20%7Blabel%7D%22)%0A%20%20%20%20%20%20%20%20ax1.plot(epochs_seen%2C%20val_values%2C%20linestyle%3D%22-.%22%2C%20label%3Df%22Validation%20%7Blabel%7D%22)%0A%20%20%20%20%20%20%20%20ax1.set_xlabel(%22Epochs%22)%0A%20%20%20%20%20%20%20%20ax1.set_ylabel(label.capitalize())%0A%20%20%20%20%20%20%20%20ax1.legend()%0A%0A%20%20%20%20%20%20%20%20%23%20Create%20a%20second%20x-axis%20for%20examples%20seen%0A%20%20%20%20%20%20%20%20ax2%20%3D%20ax1.twiny()%20%20%23%20Create%20a%20second%20x-axis%20that%20shares%20the%20same%20y-axis%0A%20%20%20%20%20%20%20%20ax2.plot(examples_seen%2C%20train_values%2C%20alpha%3D0)%20%20%23%20Invisible%20plot%20for%20aligning%20ticks%0A%20%20%20%20%20%20%20%20ax2.set_xlabel(%22Examples%20seen%22)%0A%0A%20%20%20%20%20%20%20%20fig.tight_layout()%20%20%23%20Adjust%20layout%20to%20make%20room%0A%20%20%20%20%20%20%20%20plt.savefig(data_root%20%2F%20f%22%7Blabel%7D-plot.pdf%22)%0A%20%20%20%20%20%20%20%20plt.show()%0A%20%20%20%20return%20(plot_values%2C)%0A%0A%0A%40app.cell%0Adef%20_(examples_seen%2C%20num_epochs%2C%20plot_values%2C%20train_losses%2C%20val_losses)%3A%0A%20%20%20%20_epochs_tensor%20%3D%20torch.linspace(0%2C%20num_epochs%2C%20len(train_losses))%0A%20%20%20%20_examples_seen_tensor%20%3D%20torch.linspace(0%2C%20examples_seen%2C%20len(train_losses))%0A%0A%20%20%20%20plot_values(_epochs_tensor%2C%20_examples_seen_tensor%2C%20train_losses%2C%20val_losses)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Plot%20the%20accuracies%20also.%20We%20used%205%20bathes%20to%20evaluate%20these%20(see%20the%20%60eval_iter%60%20argument).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(examples_seen%2C%20num_epochs%2C%20plot_values%2C%20train_accs%2C%20val_accs)%3A%0A%20%20%20%20_epochs_tensor%20%3D%20torch.linspace(0%2C%20num_epochs%2C%20len(train_accs))%0A%20%20%20%20_examples_seen_tensor%20%3D%20torch.linspace(0%2C%20examples_seen%2C%20len(train_accs))%0A%0A%20%20%20%20plot_values(_epochs_tensor%2C%20_examples_seen_tensor%2C%20train_accs%2C%20val_accs%2C%20label%3D%22accuracy%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20results%20for%20the%20final%20model%20and%20entire%20dataloader%20show%20the%20accuracies%20are%20higher%20than%2090%25.%20The%20fact%20that%20the%20test%20accuracy%20is%20smaller%20than%20others%20shows%20small%20overfitting.%20This%20differences%20may%20be%20removed%20by%20hyper%20parameter%20tuning%20for%20%60drop_rate%60%2C%20%60weight_decay%60%2C%20and%20so%20on.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20test_loader%2C%20train_loader%2C%20val_loader)%3A%0A%20%20%20%20_train_accuracy%20%3D%20calc_accuracy_loader(train_loader%2C%20model%2C%20device)%0A%20%20%20%20_val_accuracy%20%3D%20calc_accuracy_loader(val_loader%2C%20model%2C%20device)%0A%20%20%20%20_test_accuracy%20%3D%20calc_accuracy_loader(test_loader%2C%20model%2C%20device)%0A%0A%20%20%20%20print(f%22Training%20accuracy%3A%20%7B_train_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20print(f%22Validation%20accuracy%3A%20%7B_val_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20print(f%22Test%20accuracy%3A%20%7B_test_accuracy*100%3A.2f%7D%25%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%206.8%20Using%20the%20LLM%20as%20a%20spam%20classifier%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Define%20a%20function%20to%20do%20preprocessing%2C%20inference%2C%20and%20postprocessing%20to%20answer%20whether%20the%20input%20text%20is%20spam%20or%20ham%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20classify_review(text%2C%20model%2C%20tokenizer%2C%20device%2C%20max_length%3DNone%2C%20pad_token_id%3D50256)%3A%0A%20%20%20%20%22%22%22%0A%20%20%20%20Classify%20a%20given%20text%20as%20%22spam%22%20or%20%22not%20spam%22%20using%20the%20provided%20model%20and%20tokenizer.%0A%0A%20%20%20%20%3Aparam%20text%3A%20The%20input%20text%20to%20classify.%0A%20%20%20%20%3Aparam%20model%3A%20The%20trained%20classification%20model.%0A%20%20%20%20%3Aparam%20tokenizer%3A%20The%20tokenizer%20used%20to%20encode%20the%20text.%0A%20%20%20%20%3Aparam%20device%3A%20The%20device%20to%20run%20the%20model%20on%20(e.g.%2C%20%22cpu%22%20or%20%22cuda%22).%0A%20%20%20%20%3Aparam%20max_length%3A%20The%20maximum%20length%20for%20the%20input%20sequence.%20If%20None%2C%20use%20the%20model's%20context%20length.%0A%20%20%20%20%3Aparam%20pad_token_id%3A%20The%20token%20ID%20used%20for%20padding%20sequences.%0A%20%20%20%20%3Areturn%3A%20%22spam%22%20if%20the%20text%20is%20classified%20as%20spam%2C%20otherwise%20%22not%20spam%22.%0A%20%20%20%20%22%22%22%0A%20%20%20%20model.eval()%0A%0A%20%20%20%20%23%20Prepare%20inputs%20to%20the%20model%0A%20%20%20%20input_ids%20%3D%20tokenizer.encode(text)%0A%20%20%20%20supported_context_length%20%3D%20model.pos_emb.weight.shape%5B0%5D%0A%20%20%20%20%23%20Note%3A%20In%20the%20book%2C%20this%20was%20originally%20written%20as%20pos_emb.weight.shape%5B1%5D%20by%20mistake%0A%20%20%20%20%23%20It%20didn't%20break%20the%20code%20but%20would%20have%20caused%20unnecessary%20truncation%20(to%20768%20instead%20of%201024)%0A%0A%20%20%20%20%23%20Truncate%20sequences%20if%20they%20too%20long%0A%20%20%20%20input_ids%20%3D%20input_ids%5B%3Amin(max_length%2C%20supported_context_length)%5D%0A%20%20%20%20assert%20max_length%20is%20not%20None%2C%20(%0A%20%20%20%20%20%20%20%20%22max_length%20must%20be%20specified.%20If%20you%20want%20to%20use%20the%20full%20model%20context%2C%20%22%0A%20%20%20%20%20%20%20%20%22pass%20max_length%3Dmodel.pos_emb.weight.shape%5B0%5D.%22%0A%20%20%20%20)%0A%20%20%20%20assert%20max_length%20%3C%3D%20supported_context_length%2C%20(%0A%20%20%20%20%20%20%20%20f%22max_length%20(%7Bmax_length%7D)%20exceeds%20model's%20supported%20context%20length%20(%7Bsupported_context_length%7D).%22%0A%20%20%20%20)%20%20%20%20%0A%20%20%20%20%23%20Alternatively%2C%20a%20more%20robust%20version%20is%20the%20following%20one%2C%20which%20handles%20the%20max_length%3DNone%20case%20better%0A%20%20%20%20%23%20max_len%20%3D%20min(max_length%2Csupported_context_length)%20if%20max_length%20else%20supported_context_length%0A%20%20%20%20%23%20input_ids%20%3D%20input_ids%5B%3Amax_len%5D%0A%0A%20%20%20%20%23%20Pad%20sequences%20to%20the%20longest%20sequence%0A%20%20%20%20input_ids%20%2B%3D%20%5Bpad_token_id%5D%20*%20(max_length%20-%20len(input_ids))%0A%20%20%20%20input_tensor%20%3D%20torch.tensor(input_ids%2C%20device%3Ddevice).unsqueeze(0)%20%23%20add%20batch%20dimension%0A%0A%20%20%20%20%23%20Model%20inference%0A%20%20%20%20with%20torch.no_grad()%3A%0A%20%20%20%20%20%20%20%20logits%20%3D%20model(input_tensor)%5B%3A%2C%20-1%2C%20%3A%5D%20%20%23%20Logits%20of%20the%20last%20output%20token%0A%20%20%20%20predicted_label%20%3D%20torch.argmax(logits%2C%20dim%3D-1).item()%0A%0A%20%20%20%20%23%20Return%20the%20classified%20result%0A%20%20%20%20return%20%22spam%22%20if%20predicted_label%20%3D%3D%201%20else%20%22not%20spam%22%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20for%20spam%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20tokenizer%2C%20train_dataset)%3A%0A%20%20%20%20_text_1%20%3D%20(%0A%20%20%20%20%20%20%20%20%22You%20are%20a%20winner%20you%20have%20been%20specially%22%0A%20%20%20%20%20%20%20%20%22%20selected%20to%20receive%20%241000%20cash%20or%20a%20%242000%20award.%22%0A%20%20%20%20)%0A%0A%20%20%20%20print(classify_review(%0A%20%20%20%20%20%20%20%20_text_1%2C%20model%2C%20tokenizer%2C%20device%2C%20max_length%3Dtrain_dataset.max_length%0A%20%20%20%20))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20for%20ham%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20tokenizer%2C%20train_dataset)%3A%0A%20%20%20%20_text_2%20%3D%20(%0A%20%20%20%20%20%20%20%20%22Hey%2C%20just%20wanted%20to%20check%20if%20we're%20still%20on%22%0A%20%20%20%20%20%20%20%20%22%20for%20dinner%20tonight%3F%20Let%20me%20know!%22%0A%20%20%20%20)%0A%0A%20%20%20%20print(classify_review(%0A%20%20%20%20%20%20%20%20_text_2%2C%20model%2C%20tokenizer%2C%20device%2C%20max_length%3Dtrain_dataset.max_length%0A%20%20%20%20))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20save%20the%20result%20for%20the%20fine-tuning%20by%20this%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(model%2C%20project_root)%3A%0A%20%20%20%20fine_model_path%20%3D%20project_root%20%2F%20%22models%22%20%2F%20%22ch06%22%0A%20%20%20%20os.makedirs(fine_model_path%2C%20exist_ok%3DTrue)%0A%20%20%20%20torch.save(model.state_dict()%2C%20fine_model_path%20%2F%20%22review_classifier.pth%22)%0A%20%20%20%20return%20(fine_model_path%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20And%20load%20the%20model%20by%20this%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(fine_model_path%2C%20model)%3A%0A%20%20%20%20_model_state_dict%20%3D%20torch.load(fine_model_path%20%2F%20%22review_classifier.pth%22%2C%20map_location%3Ddevice%2C%20weights_only%3DTrue)%0A%20%20%20%20model.load_state_dict(_model_state_dict)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
ad7a3daf429c55aace2074534989c98e