import%20marimo%0A%0A__generated_with%20%3D%20%220.18.4%22%0Aapp%20%3D%20marimo.App()%0A%0Awith%20app.setup%3A%0A%20%20%20%20import%20json%0A%20%20%20%20import%20os%0A%20%20%20%20import%20urllib.request%0A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20tensorflow%20as%20tf%0A%20%20%20%20from%20tqdm%20import%20tqdm%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20overall%20pipeline%20to%20download%20model%20parameters%20from%20OpenAI%20Public%2C%20and%20load%20the%20model%20weights%20as%20a%20dictionary.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20download_and_load_gpt2(model_size%2C%20models_dir)%3A%0A%20%20%20%20%23%20Validate%20model%20size%0A%20%20%20%20allowed_sizes%20%3D%20(%22124M%22%2C%20%22355M%22%2C%20%22774M%22%2C%20%221558M%22)%0A%20%20%20%20if%20model_size%20not%20in%20allowed_sizes%3A%0A%20%20%20%20%20%20%20%20raise%20ValueError(f%22Model%20size%20not%20in%20%7Ballowed_sizes%7D%22)%0A%0A%20%20%20%20%23%20Define%20paths%0A%20%20%20%20model_dir%20%3D%20os.path.join(models_dir%2C%20model_size)%0A%20%20%20%20base_url%20%3D%20%22https%3A%2F%2Fopenaipublic.blob.core.windows.net%2Fgpt-2%2Fmodels%22%0A%20%20%20%20backup_base_url%20%3D%20%22https%3A%2F%2Ff001.backblazeb2.com%2Ffile%2FLLMs-from-scratch%2Fgpt2%22%0A%20%20%20%20filenames%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%22checkpoint%22%2C%20%22encoder.json%22%2C%20%22hparams.json%22%2C%0A%20%20%20%20%20%20%20%20%22model.ckpt.data-00000-of-00001%22%2C%20%22model.ckpt.index%22%2C%0A%20%20%20%20%20%20%20%20%22model.ckpt.meta%22%2C%20%22vocab.bpe%22%0A%20%20%20%20%5D%0A%0A%20%20%20%20%23%20Download%20files%0A%20%20%20%20os.makedirs(model_dir%2C%20exist_ok%3DTrue)%0A%20%20%20%20for%20filename%20in%20filenames%3A%0A%20%20%20%20%20%20%20%20file_url%20%3D%20os.path.join(base_url%2C%20model_size%2C%20filename)%0A%20%20%20%20%20%20%20%20backup_url%20%3D%20os.path.join(backup_base_url%2C%20model_size%2C%20filename)%0A%20%20%20%20%20%20%20%20file_path%20%3D%20os.path.join(model_dir%2C%20filename)%0A%20%20%20%20%20%20%20%20download_file(file_url%2C%20file_path%2C%20backup_url)%0A%0A%20%20%20%20%23%20Load%20settings%20and%20params%0A%20%20%20%20tf_ckpt_path%20%3D%20tf.train.latest_checkpoint(model_dir)%0A%20%20%20%20settings%20%3D%20json.load(open(os.path.join(model_dir%2C%20%22hparams.json%22)%2C%20%22r%22%2C%20encoding%3D%22utf-8%22))%0A%20%20%20%20params%20%3D%20load_gpt2_params_from_tf_ckpt(tf_ckpt_path%2C%20settings)%0A%0A%20%20%20%20return%20settings%2C%20params%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20downloader%20function%20supports%20file%20checking%2C%20chunk%20downloading%2C%20and%20backup%20URL%20support.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20download_file(url%2C%20destination%2C%20backup_url%3DNone)%3A%0A%20%20%20%20def%20_attempt_download(download_url)%3A%0A%20%20%20%20%20%20%20%20with%20urllib.request.urlopen(download_url)%20as%20response%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Get%20the%20total%20file%20size%20from%20headers%2C%20defaulting%20to%200%20if%20not%20present%0A%20%20%20%20%20%20%20%20%20%20%20%20file_size%20%3D%20int(response.headers.get(%22Content-Length%22%2C%200))%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Check%20if%20file%20exists%20and%20has%20the%20same%20size%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20os.path.exists(destination)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20file_size_local%20%3D%20os.path.getsize(destination)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20file_size%20%3D%3D%20file_size_local%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20print(f%22File%20already%20exists%20and%20is%20up-to-date%3A%20%7Bdestination%7D%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20True%20%20%23%20Indicate%20success%20without%20re-downloading%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20block_size%20%3D%201024%20%20%23%201%20Kilobyte%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Initialize%20the%20progress%20bar%20with%20total%20file%20size%0A%20%20%20%20%20%20%20%20%20%20%20%20progress_bar_description%20%3D%20os.path.basename(download_url)%0A%20%20%20%20%20%20%20%20%20%20%20%20with%20tqdm(total%3Dfile_size%2C%20unit%3D%22iB%22%2C%20unit_scale%3DTrue%2C%20desc%3Dprogress_bar_description)%20as%20progress_bar%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20with%20open(destination%2C%20%22wb%22)%20as%20file%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20while%20True%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chunk%20%3D%20response.read(block_size)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20not%20chunk%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20break%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20file.write(chunk)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20progress_bar.update(len(chunk))%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20True%0A%0A%20%20%20%20try%3A%0A%20%20%20%20%20%20%20%20if%20_attempt_download(url)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%0A%20%20%20%20except%20(urllib.error.HTTPError%2C%20urllib.error.URLError)%3A%0A%20%20%20%20%20%20%20%20if%20backup_url%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20print(f%22Primary%20URL%20(%7Burl%7D)%20failed.%20Attempting%20backup%20URL%3A%20%7Bbackup_url%7D%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20try%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20_attempt_download(backup_url)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%0A%20%20%20%20%20%20%20%20%20%20%20%20except%20urllib.error.HTTPError%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20pass%0A%0A%20%20%20%20%20%20%20%20%23%20If%20we%20reach%20here%2C%20both%20attempts%20have%20failed%0A%20%20%20%20%20%20%20%20error_message%20%3D%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22Failed%20to%20download%20from%20both%20primary%20URL%20(%7Burl%7D)%22%0A%20%20%20%20%20%20%20%20%20%20%20%20f%22%7B'%20and%20backup%20URL%20('%20%2B%20backup_url%20%2B%20')'%20if%20backup_url%20else%20''%7D.%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%5CnCheck%20your%20internet%20connection%20or%20the%20file%20availability.%5Cn%22%0A%20%20%20%20%20%20%20%20%20%20%20%20%22For%20help%2C%20visit%3A%20https%3A%2F%2Fgithub.com%2Frasbt%2FLLMs-from-scratch%2Fdiscussions%2F273%22%0A%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20print(error_message)%0A%20%20%20%20except%20Exception%20as%20e%3A%0A%20%20%20%20%20%20%20%20print(f%22An%20unexpected%20error%20occurred%3A%20%7Be%7D%22)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20function%20to%20load%20a%20TensorFlow%20checkpoint%20to%20dictionary%20variable%20with%20proper%20names%20that%20the%20prefix%20%22h%22%20is%20omitted.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20load_gpt2_params_from_tf_ckpt(ckpt_path%2C%20settings)%3A%0A%20%20%20%20%23%20Initialize%20parameters%20dictionary%20with%20empty%20blocks%20for%20each%20layer%0A%20%20%20%20params%20%3D%20%7B%22blocks%22%3A%20%5B%7B%7D%20for%20_%20in%20range(settings%5B%22n_layer%22%5D)%5D%7D%0A%0A%20%20%20%20%23%20Iterate%20over%20each%20variable%20in%20the%20checkpoint%0A%20%20%20%20for%20name%2C%20_%20in%20tf.train.list_variables(ckpt_path)%3A%0A%20%20%20%20%20%20%20%20%23%20Load%20the%20variable%20and%20remove%20singleton%20dimensions%0A%20%20%20%20%20%20%20%20variable_array%20%3D%20np.squeeze(tf.train.load_variable(ckpt_path%2C%20name))%0A%0A%20%20%20%20%20%20%20%20%23%20Process%20the%20variable%20name%20to%20extract%20relevant%20parts%0A%20%20%20%20%20%20%20%20variable_name_parts%20%3D%20name.split(%22%2F%22)%5B1%3A%5D%20%20%23%20Skip%20the%20'model%2F'%20prefix%0A%0A%20%20%20%20%20%20%20%20%23%20Identify%20the%20target%20dictionary%20for%20the%20variable%0A%20%20%20%20%20%20%20%20target_dict%20%3D%20params%0A%20%20%20%20%20%20%20%20if%20variable_name_parts%5B0%5D.startswith(%22h%22)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20layer_number%20%3D%20int(variable_name_parts%5B0%5D%5B1%3A%5D)%0A%20%20%20%20%20%20%20%20%20%20%20%20target_dict%20%3D%20params%5B%22blocks%22%5D%5Blayer_number%5D%0A%0A%20%20%20%20%20%20%20%20%23%20Recursively%20access%20or%20create%20nested%20dictionaries%0A%20%20%20%20%20%20%20%20for%20key%20in%20variable_name_parts%5B1%3A-1%5D%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20target_dict%20%3D%20target_dict.setdefault(key%2C%20%7B%7D)%0A%0A%20%20%20%20%20%20%20%20%23%20Assign%20the%20variable%20array%20to%20the%20last%20key%0A%20%20%20%20%20%20%20%20last_key%20%3D%20variable_name_parts%5B-1%5D%0A%20%20%20%20%20%20%20%20target_dict%5Blast_key%5D%20%3D%20variable_array%0A%0A%20%20%20%20return%20params%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
89f30f72a3aa0e427571de6b8be1e0c8