import%20marimo%0A%0A__generated_with%20%3D%20%220.18.4%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0Awith%20app.setup%3A%0A%20%20%20%20from%20importlib.metadata%20import%20version%0A%0A%20%20%20%20import%20marimo%20as%20mo%0A%20%20%20%20import%20torch%0A%20%20%20%20import%20torch.nn%20as%20nn%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%20Chapter%203%3A%20Coding%20Attention%20Mechanisms%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20print(%22torch%20version%3A%22%2C%20version(%22torch%22))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%203.3%20Attending%20to%20different%20parts%20of%20the%20input%20with%20self-attention%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.3.1%20A%20simple%20self-attention%20mechanism%20without%20trainable%20weights%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Prepare%20embedded%20vectors%20as%20inputs%20to%20attention%20layers.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20inputs%20%3D%20torch.tensor(%0A%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.43%2C%200.15%2C%200.89%5D%2C%20%20%23%20Your%20%20%20%20%20(x%5E1)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.55%2C%200.87%2C%200.66%5D%2C%20%20%23%20journey%20%20(x%5E2)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.57%2C%200.85%2C%200.64%5D%2C%20%20%23%20starts%20%20%20(x%5E3)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.22%2C%200.58%2C%200.33%5D%2C%20%20%23%20with%20%20%20%20%20(x%5E4)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.77%2C%200.25%2C%200.10%5D%2C%20%20%23%20one%20%20%20%20%20%20(x%5E5)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B0.05%2C%200.80%2C%200.55%5D%2C%20%20%23%20step%20%20%20%20%20(x%5E6)%0A%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20)%0A%20%20%20%20return%20(inputs%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Attenstion%20score%20%24%5Comega_%7B2i%7D%24%20is%20defined%20as%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cvec%7B%5Comega%7D_%7B2i%7D%20%5Cequiv%20%5Cvec%7Bx%7D_2%20%5Ccdot%20%5Cvec%7Bx%7D_i.%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs)%3A%0A%20%20%20%20query%20%3D%20inputs%5B1%5D%20%20%23%202nd%20input%20token%20is%20the%20query%0A%0A%20%20%20%20attn_scores_2_trainless%20%3D%20torch.empty(inputs.shape%5B0%5D)%0A%20%20%20%20for%20_i%2C%20_x_i%20in%20enumerate(inputs)%3A%0A%20%20%20%20%20%20%20%20%23%20dot%20product%20(transpose%20not%20necessary%20here%20since%20they%20are%201-dim%20vectors)%0A%20%20%20%20%20%20%20%20attn_scores_2_trainless%5B_i%5D%20%3D%20torch.dot(_x_i%2C%20query)%0A%20%20%20%20print(attn_scores_2_trainless)%0A%20%20%20%20return%20attn_scores_2_trainless%2C%20query%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20a%20verification%20of%20dot%20product%20for%20beginners%20by%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cvec%7Bx%7D_0%5Ccdot%5Cvec%7Bx%7D_1%5Cequiv%5Csum_%7Bi%3D0%7D%5E%7B%5Cmathrm%7Bdim%7D(%5Cvec%7Bx%7D_0)%7D(x_0)_i(x_1)_i.%0A%20%20%20%20%24%24%0A%20%20%20%20The%20dot%20product%20has%20meaning%20of%20similarity%20between%20two%20vectors.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs%2C%20query)%3A%0A%20%20%20%20_res%20%3D%200.0%0A%20%20%20%20for%20_idx%2C%20_element%20in%20enumerate(inputs%5B0%5D)%3A%0A%20%20%20%20%20%20%20%20_res%20%2B%3D%20inputs%5B0%5D%5B_idx%5D%20*%20query%5B_idx%5D%0A%0A%20%20%20%20print(_res)%0A%20%20%20%20print(torch.dot(inputs%5B0%5D%2C%20query))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20attention%20should%20be%20normalized%20to%20represents%20weights.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores_2_trainless)%3A%0A%20%20%20%20_attn_weights_2_tmp%20%3D%20attn_scores_2_trainless%20%2F%20attn_scores_2_trainless.sum()%0A%0A%20%20%20%20print(%22Attention%20weights%3A%22%2C%20_attn_weights_2_tmp)%0A%20%20%20%20print(%22Sum%3A%22%2C%20_attn_weights_2_tmp.sum())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20More%20suitable%20function%20for%20normalization%20is%20the%20softmax%20function%20that%20has%20positive%20values%20and%20robust%20gradient%20for%20traning%0A%20%20%20%20%24%24%0A%20%20%20%20%5Csigma(x_i%3B%5Cvec%7Bx%7D)%5Cequiv%5Cfrac%7Be%5E%7Bx_i%7D%7D%7B%5Csum_j%20e%5E%7Bx_j%7D%7D.%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function%0Adef%20softmax_naive(x)%3A%0A%20%20%20%20return%20torch.exp(x)%20%2F%20torch.exp(x).sum(dim%3D0)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20softmax%20function%20also%20leads%20weights%20whose%20total%20is%201.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores_2_trainless)%3A%0A%20%20%20%20_attn_weights_2_naive%20%3D%20softmax_naive(attn_scores_2_trainless)%0A%0A%20%20%20%20print(%22Attention%20weights%3A%22%2C%20_attn_weights_2_naive)%0A%20%20%20%20print(%22Sum%3A%22%2C%20_attn_weights_2_naive.sum())%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20PyTorch%20softmax%20function%20is%20more%20stable%20with%20respect%20to%20numerical%20errors%20because%20it%20uses%20%5Bmax-trick%20and%20LogSumExp%20trick%5D(https%3A%2F%2Fdiscuss.pytorch.org%2Ft%2Fjustification-for-logsoftmax-being-better-than-log-softmax%2F140130%2F3).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores_2_trainless)%3A%0A%20%20%20%20attn_weights_2_trainless%20%3D%20torch.softmax(attn_scores_2_trainless%2C%20dim%3D0)%0A%0A%20%20%20%20print(%22Attention%20weights%3A%22%2C%20attn_weights_2_trainless)%0A%20%20%20%20print(%22Sum%3A%22%2C%20attn_weights_2_trainless.sum())%0A%20%20%20%20return%20(attn_weights_2_trainless%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20attention%20scores%20are%20used%20as%20weights%20for%20weighted%20average%20of%20embedded%20vectors.%0A%20%20%20%20The%20result%20is%20called%20as%20context%20vectors%20%24%5Cvec%7Bz%7D%5E%7B(i)%7D%24.%0A%0A%20%20%20%20Formaly%2C%20the%20context%20vectors%20%24%5Cvec%7Bz%7D%5E%7B(i)%7D%24%20are%20defined%20as%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cvec%7Bz%7D%5E%7B(i)%7D%0A%20%20%20%20%5Cequiv%5Csum_j%5Cmathrm%7BSoftmax%7D_j(%5Comega_%7Bij%7D)%5Cvec%7Bx%7D_j%0A%20%20%20%20%5Cequiv%5Cfrac%7B%5Csum_j%5Comega_%7Bij%7D%5Cvec%7Bx%7D_j%7D%7B%5Csum_k%5Comega_%7Bik%7D%7D%0A%20%20%20%20%5Cequiv%5Cfrac%7B(%5Csum_j%5Cvec%7Bx%7D_j%5Cvec%7Bx%7D_j%5ET)%5Cvec%7Bx%7D_i%7D%7B%5Csum_k%5Cvec%7Bx%7D_k%5Ccdot%5Cvec%7Bx%7D_i%7D%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights_2_trainless%2C%20inputs%2C%20query)%3A%0A%20%20%20%20context_vec_2_trainless%20%3D%20torch.zeros(query.shape)%0A%20%20%20%20for%20_i%2C%20_x_i%20in%20enumerate(inputs)%3A%0A%20%20%20%20%20%20%20%20context_vec_2_trainless%20%2B%3D%20attn_weights_2_trainless%5B_i%5D%20*%20_x_i%0A%0A%20%20%20%20print(context_vec_2_trainless)%0A%20%20%20%20return%20(context_vec_2_trainless%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.3.2%20Computing%20attention%20weights%20for%20all%20input%20tokens%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Calculate%20context%20vectors%20with%20respect%20to%20all%20query%20vectors.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs)%3A%0A%20%20%20%20_attn_scores%20%3D%20torch.empty(6%2C%206)%0A%0A%20%20%20%20for%20_i%2C%20_x_i%20in%20enumerate(inputs)%3A%0A%20%20%20%20%20%20%20%20for%20_j%2C%20_x_j%20in%20enumerate(inputs)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20_attn_scores%5B_i%2C%20_j%5D%20%3D%20torch.dot(_x_i%2C%20_x_j)%0A%20%20%20%20print(_attn_scores)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20In%20Python%2C%20such%20double%20for-loops%20are%20inefficient.%20The%20following%20matrix%20product%20is%20preferred%20for%20computation%20efficiency.%20This%20result%20is%20the%20same%20with%20the%20result%20of%20the%20previous%20cell.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs)%3A%0A%20%20%20%20attn_scores%20%3D%20inputs%20%40%20inputs.T%0A%20%20%20%20print(attn_scores)%0A%20%20%20%20return%20(attn_scores%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20can%20apply%20softmax%20function%20to%20each%20rows.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores)%3A%0A%20%20%20%20attn_weights%20%3D%20torch.softmax(attn_scores%2C%20dim%3D-1)%0A%20%20%20%20print(attn_weights)%0A%20%20%20%20return%20(attn_weights%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20normalization%20is%20verified%20by%20this.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights)%3A%0A%20%20%20%20row_2_sum%20%3D%20sum(%5B0.1385%2C%200.2379%2C%200.2333%2C%200.1240%2C%200.1082%2C%200.1581%5D)%0A%20%20%20%20print(%22Row%202%20sum%3A%22%2C%20row_2_sum)%0A%0A%20%20%20%20print(%22All%20row%20sums%3A%22%2C%20attn_weights.sum(dim%3D-1))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Weighted%20averages%20are%20also%20simplified%20by%20using%20matrix-vector%20products.%0A%20%20%20%20Finally%2C%20we%20obtain%20the%20following%20formula%3A%0A%20%20%20%20%24%24%0A%20%20%20%20Z%5Cequiv%5Cmathrm%7BSoftmax%7D(XX%5ET)X%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights%2C%20inputs)%3A%0A%20%20%20%20all_context_vecs%20%3D%20attn_weights%20%40%20inputs%0A%20%20%20%20print(all_context_vecs)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%202nd%20slice%20is%20the%20exactly%20same%20with%20the%20result%20of%20the%20previous%20section.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(context_vec_2_trainless)%3A%0A%20%20%20%20print(%22Previous%202nd%20context%20vector%3A%22%2C%20context_vec_2_trainless)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%203.4%20Implementing%20self-attention%20with%20trainable%20weights%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.4.1%20Computing%20the%20attention%20weights%20step%20by%20step%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Use%20these%20inputs%20and%20parameters%20to%20show%20how%20to%20introduce%20trainable%20attention%20layer.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs)%3A%0A%20%20%20%20x_2%20%3D%20inputs%5B1%5D%20%20%23%20second%20input%20element%0A%20%20%20%20d_in%20%3D%20inputs.shape%5B1%5D%20%20%23%20the%20input%20embedding%20size%2C%20d%3D3%0A%20%20%20%20d_out%20%3D%202%20%20%23%20the%20output%20embedding%20size%2C%20d%3D2%0A%20%20%20%20return%20d_in%2C%20d_out%2C%20x_2%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20These%20are%20trainable%20weight%20matricies%20(or%20linear%20projections).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(d_in%2C%20d_out)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20%23%20requires_grad%3DFalse%20is%20for%20simplification%20purposes%20only%20(True%20is%20required%20for%20training)%0A%20%20%20%20W_query%20%3D%20torch.nn.Parameter(torch.rand(d_in%2C%20d_out)%2C%20requires_grad%3DFalse)%0A%20%20%20%20W_key%20%3D%20torch.nn.Parameter(torch.rand(d_in%2C%20d_out)%2C%20requires_grad%3DFalse)%0A%20%20%20%20W_value%20%3D%20torch.nn.Parameter(torch.rand(d_in%2C%20d_out)%2C%20requires_grad%3DFalse)%0A%20%20%20%20return%20W_key%2C%20W_query%2C%20W_value%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20linear%20projection%20are%20like%20these%0A%20%20%20%20%24%24%0A%20%20%20%20Q_2%5Cequiv%20%5Cvec%7Bx%7D_2%20W%5EQ%20%2C%5Cquad%0A%20%20%20%20K_2%5Cequiv%20%5Cvec%7Bx%7D_2%20W%5EK%2C%5Cquad%0A%20%20%20%20V_2%5Cequiv%20%5Cvec%7Bx%7D_2%20W%5EV%2C%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(W_key%2C%20W_query%2C%20W_value%2C%20x_2)%3A%0A%20%20%20%20query_2%20%3D%20x_2%20%40%20W_query%20%20%23%20_2%20because%20it's%20with%20respect%20to%20the%202nd%20input%20element%0A%20%20%20%20key_2%20%3D%20x_2%20%40%20W_key%0A%20%20%20%20value_2%20%3D%20x_2%20%40%20W_value%0A%0A%20%20%20%20print(query_2)%0A%20%20%20%20return%20(query_2%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Linear%20projection%20for%20all%20inputs%20are%0A%20%20%20%20%24%24%0A%20%20%20%20K%3DXW%5EK%2C%20V%3DXW%5EV.%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(W_key%2C%20W_value%2C%20inputs)%3A%0A%20%20%20%20keys%20%3D%20inputs%20%40%20W_key%0A%20%20%20%20values%20%3D%20inputs%20%40%20W_value%0A%0A%20%20%20%20print(%22keys.shape%3A%22%2C%20keys.shape)%0A%20%20%20%20print(%22values.shape%3A%22%2C%20values.shape)%0A%20%20%20%20return%20keys%2C%20values%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20attention%20score%20%24%5Comega_%7B22%7D%5Cequiv%20Q_2%5Ccdot%20K_2%24%20is%20this.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(keys%2C%20query_2)%3A%0A%20%20%20%20keys_2%20%3D%20keys%5B1%5D%20%20%23%20Python%20starts%20index%20at%200%0A%20%20%20%20attn_score_22%20%3D%20query_2.dot(keys_2)%0A%20%20%20%20print(attn_score_22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Scores%20for%20all%20keys%20are%20calculated%20by%20%24%5Comega_2%5Cequiv%20Q_2%20K%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(keys%2C%20query_2)%3A%0A%20%20%20%20attn_scores_2%20%3D%20query_2%20%40%20keys.T%20%20%23%20All%20attention%20scores%20for%20given%20query%0A%20%20%20%20print(attn_scores_2)%0A%20%20%20%20return%20(attn_scores_2%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Take%20the%20attention%20scores%20by%20Scaled%20Dot-Product%20Attention%20way.%0A%20%20%20%20The%20attention%20score%20is%20scaled%20by%20square%20root%20of%20the%20dimension%20of%20keys%20like%0A%20%20%20%20%24%24%0A%20%20%20%20%5Calpha_2%5Cequiv%5Cmathrm%7BSoftmax%7D%5Cfrac%7B%5Comega_2%7D%7B%5Csqrt%7B%5Cmathrm%7Bdim%7D(K_i)%7D%7D%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores_2%2C%20keys)%3A%0A%20%20%20%20d_k%20%3D%20keys.shape%5B1%5D%0A%20%20%20%20attn_weights_2%20%3D%20torch.softmax(attn_scores_2%20%2F%20d_k**0.5%2C%20dim%3D-1)%0A%20%20%20%20print(attn_weights_2)%0A%20%20%20%20return%20(attn_weights_2%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Finally%2C%20we%20get%20trainable%20context%20vector%20as%0A%20%20%20%20%24%24%0A%20%20%20%20%5Cvec%7Bz%7D%5E2%5Cequiv%5Cmathrm%7BSoftmax%7D%5Cleft(%5Cfrac%7BQ_2K%7D%7B%5Csqrt%7B%5Cmathrm%7Bdim%7D(K_i)%7D%7D%5Cright)V.%0A%20%20%20%20%24%24%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights_2%2C%20values)%3A%0A%20%20%20%20context_vec_2%20%3D%20attn_weights_2%20%40%20values%0A%20%20%20%20print(context_vec_2)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.4.2%20Implementing%20a%20compact%20SelfAttention%20class%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20SelfAttention_v1(nn.Module)%3A%0A%20%20%20%20def%20__init__(self%2C%20d_in%2C%20d_out)%3A%0A%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20%23%20trainable%20parameters%20are%20stored%20as%20member%20variables%0A%20%20%20%20%20%20%20%20self.W_query%20%3D%20nn.Parameter(torch.rand(d_in%2C%20d_out))%0A%20%20%20%20%20%20%20%20self.W_key%20%3D%20nn.Parameter(torch.rand(d_in%2C%20d_out))%0A%20%20%20%20%20%20%20%20self.W_value%20%3D%20nn.Parameter(torch.rand(d_in%2C%20d_out))%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20keys%20%3D%20x%20%40%20self.W_key%0A%20%20%20%20%20%20%20%20queries%20%3D%20x%20%40%20self.W_query%0A%20%20%20%20%20%20%20%20values%20%3D%20x%20%40%20self.W_value%0A%0A%20%20%20%20%20%20%20%20attn_scores%20%3D%20queries%20%40%20keys.T%20%20%23%20omega%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20torch.softmax(attn_scores%20%2F%20keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20attn_weights%20%40%20values%0A%20%20%20%20%20%20%20%20return%20context_vec%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20trainable%20context%20vectors%20are%20calculated%20by%20this%20instance.%20The%202nd%20slice%20of%20the%20result%20is%20the%20same%20with%20the%20result%20of%20the%20previous%20section.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(d_in%2C%20d_out%2C%20inputs)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20sa_v1%20%3D%20SelfAttention_v1(d_in%2C%20d_out)%0A%20%20%20%20print(sa_v1(inputs))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20By%20using%20%60nn.Linear%60%2C%20the%20class%20definition%20is%20more%20simplified%20and%20more%20stable%20because%20of%20a%20good%20initialization%20scheme%20called%20as%20%5Bthe%20Kaiming%20Uniform%20initialization%5D(https%3A%2F%2Fdocs.pytorch.org%2Fdocs%2Fstable%2Fnn.init.html%23torch.nn.init.kaiming_uniform_).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20SelfAttention_v2(nn.Module)%3A%0A%20%20%20%20def%20__init__(self%2C%20d_in%2C%20d_out%2C%20qkv_bias%3DFalse)%3A%0A%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20self.W_query%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_key%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_value%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20keys%20%3D%20self.W_key(x)%0A%20%20%20%20%20%20%20%20queries%20%3D%20self.W_query(x)%0A%20%20%20%20%20%20%20%20values%20%3D%20self.W_value(x)%0A%0A%20%20%20%20%20%20%20%20attn_scores%20%3D%20queries%20%40%20keys.T%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20torch.softmax(attn_scores%20%2F%20keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20attn_weights%20%40%20values%0A%20%20%20%20%20%20%20%20return%20context_vec%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20output%20is%20different%20with%20the%20previous%20results%20because%20of%20such%20initialization%20schemes.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(d_in%2C%20d_out%2C%20inputs)%3A%0A%20%20%20%20torch.manual_seed(789)%0A%20%20%20%20sa_v2%20%3D%20SelfAttention_v2(d_in%2C%20d_out)%0A%20%20%20%20print(sa_v2(inputs))%0A%20%20%20%20return%20(sa_v2%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%203.5%20Hiding%20future%20words%20with%20causal%20attention%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.5.1%20Applying%20a%20causal%20attention%20mask%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20initial%20attention%20weights%20like%20previous.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs%2C%20sa_v2)%3A%0A%20%20%20%20%23%20Reuse%20the%20query%20and%20key%20weight%20matrices%20of%20the%0A%20%20%20%20%23%20SelfAttention_v2%20object%20from%20the%20previous%20section%20for%20convenience%0A%20%20%20%20_queries%20%3D%20sa_v2.W_query(inputs)%0A%20%20%20%20_keys%20%3D%20sa_v2.W_key(inputs)%0A%20%20%20%20_attn_scores%20%3D%20_queries%20%40%20_keys.T%0A%0A%20%20%20%20attn_weights_sa_v2%20%3D%20torch.softmax(_attn_scores%20%2F%20_keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%20%20%20%20print(attn_weights_sa_v2)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20mask%20the%20upper%20right%20of%20it%20to%20ensure%20causal%20token%20prediction%20tasks.%0A%20%20%20%20Such%20binary%20mask%20is%20easily%20created%20by%20%60torch.tril%60.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores)%3A%0A%20%20%20%20_context_length%20%3D%20attn_scores.shape%5B0%5D%0A%20%20%20%20mask_simple%20%3D%20torch.tril(torch.ones(_context_length%2C%20_context_length))%0A%20%20%20%20print(mask_simple)%0A%20%20%20%20return%20(mask_simple%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20masking%20is%20simple%20multiplication%20of%20these%20%24W_a%5Cmathbb%7B1%7D_%7Bi%5Cgeq%20j%7D%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights%2C%20mask_simple)%3A%0A%20%20%20%20masked_simple%20%3D%20attn_weights%20*%20mask_simple%0A%20%20%20%20print(masked_simple)%0A%20%20%20%20return%20(masked_simple%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Normalize%20each%20rows%20to%20use%20it%20as%20weights.%20The%20normalization%20should%20be%20done%20after%20the%20masking%20to%20prevent%20data%20leak%20from%20futures.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(masked_simple)%3A%0A%20%20%20%20row_sums%20%3D%20masked_simple.sum(dim%3D-1%2C%20keepdim%3DTrue)%0A%20%20%20%20masked_simple_norm%20%3D%20masked_simple%20%2F%20row_sums%0A%20%20%20%20print(masked_simple_norm)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20For%20more%20efficiency%2C%20this%20masking%20is%20replaced%20by%20filling%20%60-inf%60%20to%20upper%20right%20of%20attention%20scores%20before%20taking%20Softmax.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_scores%2C%20context_length)%3A%0A%20%20%20%20mask%20%3D%20torch.triu(torch.ones(context_length%2C%20context_length)%2C%20diagonal%3D1)%0A%20%20%20%20masked%20%3D%20attn_scores.masked_fill(mask.bool()%2C%20-torch.inf)%0A%20%20%20%20print(masked)%0A%20%20%20%20return%20(masked%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%24%5Cexp(-%5Cinfty)%3D0%24%20leads%20the%20same%20result%20with%20the%20masking.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(keys%2C%20masked)%3A%0A%20%20%20%20causal_attn_weights%20%3D%20torch.softmax(masked%20%2F%20keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%20%20%20%20print(causal_attn_weights)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.5.2%20Masking%20additional%20attention%20weights%20with%20dropout%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20introduct%20dropout%20to%20prevent%20overfitting%20to%20specific%20tokens.%20We%20drop%2050%25%20of%20elements%20and%20the%20remaining%20values%20are%20doubled%20to%20keep%20the%20total%20scale.%20This%20dropout%20layer%20is%20applied%20only%20for%20training.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20dropout%20%3D%20torch.nn.Dropout(0.5)%20%20%23%20dropout%20rate%20of%2050%25%0A%20%20%20%20example%20%3D%20torch.ones(6%2C%206)%20%20%23%20create%20a%20matrix%20of%20ones%0A%0A%20%20%20%20print(dropout(example))%0A%20%20%20%20return%20(dropout%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20dropout%20layer%20is%20applied%20after%20attention%20weights%20are%20calculated.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(attn_weights%2C%20dropout)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%20%20%20%20print(dropout(attn_weights))%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.5.3%20Implementing%20a%20compact%20causal%20self-attention%20class%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20We%20introduce%20such%20causal%20attention%2C%20dropout%2C%20and%20batch%20inference%20to%20attention%20layer.%20The%20batch%20input%20for%20testing%20is%20as%20follows.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(inputs)%3A%0A%20%20%20%20batch%20%3D%20torch.stack((inputs%2C%20inputs)%2C%20dim%3D0)%0A%20%20%20%20%23%202%20inputs%20with%206%20tokens%20each%2C%20and%20each%20token%20has%20embedding%20dimension%203%0A%20%20%20%20print(batch.shape)%20%20%0A%20%20%20%20return%20(batch%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20for%20causal%20attention%20layer.%20%5B%60register_buffer%60%5D(https%3A%2F%2Fdocs.pytorch.org%2Fdocs%2Fstable%2Fgenerated%2Ftorch.nn.Module.html%23torch.nn.Module.register_buffer)%20is%20used%20to%20define%20constant%20(non-trained)%20parameters%20and%20ensure%20it%20is%20allocated%20on%20suitable%20device.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20CausalAttention(nn.Module)%3A%0A%20%20%20%20def%20__init__(self%2C%20d_in%2C%20d_out%2C%20context_length%2C%20dropout%2C%20qkv_bias%3DFalse)%3A%0A%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20self.d_out%20%3D%20d_out%0A%20%20%20%20%20%20%20%20self.W_query%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_key%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_value%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%0A%20%20%20%20%20%20%20%20%23%20NEW%0A%20%20%20%20%20%20%20%20self.dropout%20%3D%20nn.Dropout(dropout)%0A%20%20%20%20%20%20%20%20%23%20NEW%3A%20allocate%20maximum%20mask%20to%20be%20sliced%20later%0A%20%20%20%20%20%20%20%20self.register_buffer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%22mask%22%2C%20torch.triu(torch.ones(context_length%2C%20context_length)%2C%20diagonal%3D1)%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20b%2C%20num_tokens%2C%20d_in%20%3D%20x.shape%20%20%23%20New%20batch%20dimension%20b%0A%0A%20%20%20%20%20%20%20%20%23%20For%20inputs%20where%20%60num_tokens%60%20exceeds%20%60context_length%60%2C%20this%20will%20result%20in%20errors%0A%20%20%20%20%20%20%20%20%23%20in%20the%20mask%20creation%20further%20below.%0A%20%20%20%20%20%20%20%20%23%20In%20practice%2C%20this%20is%20not%20a%20problem%20since%20the%20LLM%20(chapters%204-7)%20ensures%20that%20inputs%0A%20%20%20%20%20%20%20%20%23%20do%20not%20exceed%20%60context_length%60%20before%20reaching%20this%20forward%20method.%0A%20%20%20%20%20%20%20%20keys%20%3D%20self.W_key(x)%20%20%20%20%23%20(B%2CN%2Cd_out)%0A%20%20%20%20%20%20%20%20queries%20%3D%20self.W_query(x)%20%20%20%23%20(B%2CN%2Cd_out)%0A%20%20%20%20%20%20%20%20values%20%3D%20self.W_value(x)%20%20%20%20%23%20(B%2CN%2Cd_out)%0A%0A%20%20%20%20%20%20%20%20%23%20Changed%20transpose%3A%20(B%2CN%2Cd_out)%20-%3E%20(B%2Cd_out%2CN)%0A%20%20%20%20%20%20%20%20attn_scores%20%3D%20queries%20%40%20keys.transpose(1%2C%202)%20%20%20%20%23%20(B%2CN%2CN)%0A%20%20%20%20%20%20%20%20%23%20New%2C%20_%20ops%20are%20in-place%0A%20%20%20%20%20%20%20%20%23%20%60%3Anum_tokens%60%20to%20account%20for%20cases%20where%20the%20number%20of%20tokens%20in%20the%20batch%20is%20smaller%20than%20the%20supported%20context_size%0A%20%20%20%20%20%20%20%20attn_scores.masked_fill_(%20%20%0A%20%20%20%20%20%20%20%20%20%20%20%20self.mask.bool()%5B%3Anum_tokens%2C%20%3Anum_tokens%5D%2C%20-torch.inf%0A%20%20%20%20%20%20%20%20)%20%20%0A%20%20%20%20%20%20%20%20%23%20normalization%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20torch.softmax(attn_scores%20%2F%20keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20self.dropout(attn_weights)%20%20%23%20New%0A%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20attn_weights%20%40%20values%20%23%20(B%2CN%2Cd_out)%0A%20%20%20%20%20%20%20%20return%20context_vec%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20This%20is%20an%20example%20to%20use%20the%20forward%20path.%20Check%20the%20input%20and%20output%20shapes.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(batch%2C%20d_in%2C%20d_out)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20%23%20%60batch%60%20has%20(B%2CN%2CC)%20shape%0A%20%20%20%20_context_length%20%3D%20batch.shape%5B1%5D%0A%20%20%20%20ca%20%3D%20CausalAttention(d_in%2C%20d_out%2C%20_context_length%2C%200.0)%0A%0A%20%20%20%20_context_vecs%20%3D%20ca(batch)%0A%0A%20%20%20%20print(_context_vecs)%0A%0A%20%20%20%20print(f%22%7Bbatch.shape%3D%7D%22)%0A%20%20%20%20print(f%22%7B_context_vecs.shape%3D%7D%22)%0A%20%20%20%20print(f%22%7Bd_in%3D%7D%2C%20%7Bd_out%3D%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%203.6%20Extending%20single-head%20attention%20to%20multi-head%20attention%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.6.1%20Stacking%20multiple%20single-head%20attention%20layers%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Let's%20extend%20the%20attention%20layer%20to%20multi-head.%20This%20is%20a%20naive%20extetion%20from%20single-head%20one.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20MultiHeadAttentionWrapper(nn.Module)%3A%0A%20%20%20%20def%20__init__(self%2C%20d_in%2C%20d_out%2C%20context_length%2C%20dropout%2C%20num_heads%2C%20qkv_bias%3DFalse)%3A%0A%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20self.heads%20%3D%20nn.ModuleList(%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20CausalAttention(d_in%2C%20d_out%2C%20context_length%2C%20dropout%2C%20qkv_bias)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20_%20in%20range(num_heads)%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%23%20just%20apply%20single-head%20attentions%20and%20concatenate%20their%20outputs%0A%20%20%20%20%20%20%20%20return%20torch.cat(%5Bhead(x)%20for%20head%20in%20self.heads%5D%2C%20dim%3D-1)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20The%20output%20shape%20is%20%24(B%2CN%2Cd_%7Bout%7D%5Ctimes2)%24.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(batch)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20_context_length%20%3D%20batch.shape%5B1%5D%20%20%23%20This%20is%20the%20number%20of%20tokens%0A%20%20%20%20_d_in%2C%20_d_out%20%3D%203%2C%202%0A%20%20%20%20_mha%20%3D%20MultiHeadAttentionWrapper(_d_in%2C%20_d_out%2C%20_context_length%2C%200.0%2C%20num_heads%3D2)%0A%0A%20%20%20%20_context_vecs%20%3D%20_mha(batch)%0A%20%20%20%20print(_context_vecs)%0A%20%20%20%20print(f%22%7Bbatch.shape%3D%7D%22)%0A%20%20%20%20print(f%22%7B_context_vecs.shape%3D%7D%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20%23%23%23%203.6.2%20Implementing%20multi-head%20attention%20with%20weight%20splits%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20mo.md(r%22%22%22%0A%20%20%20%20Simplify%20the%20multi-head%20attention%20layer%20by%20self-contained%20way%20and%20avoid%20loops%20by%20using%20matrix%20products.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.class_definition%0Aclass%20MultiHeadAttention(nn.Module)%3A%0A%20%20%20%20def%20__init__(self%2C%20d_in%2C%20d_out%2C%20context_length%2C%20dropout%2C%20num_heads%2C%20qkv_bias%3DFalse)%3A%0A%20%20%20%20%20%20%20%20super().__init__()%0A%20%20%20%20%20%20%20%20assert%20d_out%20%25%20num_heads%20%3D%3D%200%2C%20%22d_out%20must%20be%20divisible%20by%20num_heads%22%0A%0A%20%20%20%20%20%20%20%20self.d_out%20%3D%20d_out%0A%20%20%20%20%20%20%20%20self.num_heads%20%3D%20num_heads%0A%20%20%20%20%20%20%20%20self.head_dim%20%3D%20(%0A%20%20%20%20%20%20%20%20%20%20%20%20d_out%20%2F%2F%20num_heads%0A%20%20%20%20%20%20%20%20)%20%20%23%20Reduce%20the%20projection%20dim%20to%20match%20desired%20output%20dim%0A%0A%20%20%20%20%20%20%20%20self.W_query%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_key%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.W_value%20%3D%20nn.Linear(d_in%2C%20d_out%2C%20bias%3Dqkv_bias)%0A%20%20%20%20%20%20%20%20self.out_proj%20%3D%20nn.Linear(d_out%2C%20d_out)%20%20%23%20Linear%20layer%20to%20combine%20head%20outputs%0A%20%20%20%20%20%20%20%20self.dropout%20%3D%20nn.Dropout(dropout)%0A%20%20%20%20%20%20%20%20%23%20constant%20parameters%0A%20%20%20%20%20%20%20%20self.register_buffer(%0A%20%20%20%20%20%20%20%20%20%20%20%20%22mask%22%2C%20torch.triu(torch.ones(context_length%2C%20context_length)%2C%20diagonal%3D1)%0A%20%20%20%20%20%20%20%20)%0A%0A%20%20%20%20def%20forward(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20b%2C%20num_tokens%2C%20d_in%20%3D%20x.shape%0A%20%20%20%20%20%20%20%20%23%20As%20in%20%60CausalAttention%60%2C%20for%20inputs%20where%20%60num_tokens%60%20exceeds%20%60context_length%60%2C%0A%20%20%20%20%20%20%20%20%23%20this%20will%20result%20in%20errors%20in%20the%20mask%20creation%20further%20below.%0A%20%20%20%20%20%20%20%20%23%20In%20practice%2C%20this%20is%20not%20a%20problem%20since%20the%20LLM%20(chapters%204-7)%20ensures%20that%20inputs%0A%20%20%20%20%20%20%20%20%23%20do%20not%20exceed%20%60context_length%60%20before%20reaching%20this%20forward%20method.%0A%0A%20%20%20%20%20%20%20%20keys%20%3D%20self.W_key(x)%20%20%23%20Shape%3A%20(b%2C%20num_tokens%2C%20d_out)%0A%20%20%20%20%20%20%20%20queries%20%3D%20self.W_query(x)%0A%20%20%20%20%20%20%20%20values%20%3D%20self.W_value(x)%0A%0A%20%20%20%20%20%20%20%20%23%20We%20implicitly%20split%20the%20matrix%20by%20adding%20a%20%60num_heads%60%20dimension%0A%20%20%20%20%20%20%20%20%23%20Unroll%20last%20dim%3A%20(b%2C%20num_tokens%2C%20d_out)%20-%3E%20(b%2C%20num_tokens%2C%20num_heads%2C%20head_dim)%0A%20%20%20%20%20%20%20%20%23%20So%2C%20d_out%20%3D%20num_heads%20*%20head_dim%20(see%20this%20constructor)%0A%20%20%20%20%20%20%20%20keys%20%3D%20keys.view(b%2C%20num_tokens%2C%20self.num_heads%2C%20self.head_dim)%0A%20%20%20%20%20%20%20%20values%20%3D%20values.view(b%2C%20num_tokens%2C%20self.num_heads%2C%20self.head_dim)%0A%20%20%20%20%20%20%20%20queries%20%3D%20queries.view(b%2C%20num_tokens%2C%20self.num_heads%2C%20self.head_dim)%0A%0A%20%20%20%20%20%20%20%20%23%20Transpose%3A%20(b%2C%20num_tokens%2C%20num_heads%2C%20head_dim)%20-%3E%20(b%2C%20num_heads%2C%20num_tokens%2C%20head_dim)%0A%20%20%20%20%20%20%20%20keys%20%3D%20keys.transpose(1%2C%202)%0A%20%20%20%20%20%20%20%20queries%20%3D%20queries.transpose(1%2C%202)%0A%20%20%20%20%20%20%20%20values%20%3D%20values.transpose(1%2C%202)%0A%0A%20%20%20%20%20%20%20%20%23%20Compute%20scaled%20dot-product%20attention%20(aka%20self-attention)%20with%20a%20causal%20mask%0A%20%20%20%20%20%20%20%20%23%20Shape%3A%20(b%2C%20num_heads%2C%20num_tokens%2C%20num_tokens)%0A%20%20%20%20%20%20%20%20attn_scores%20%3D%20queries%20%40%20keys.transpose(2%2C%203)%20%20%23%20Dot%20product%20for%20each%20head%0A%0A%20%20%20%20%20%20%20%20%23%20Original%20mask%20truncated%20to%20the%20number%20of%20tokens%20and%20converted%20to%20boolean%0A%20%20%20%20%20%20%20%20mask_bool%20%3D%20self.mask.bool()%5B%3Anum_tokens%2C%20%3Anum_tokens%5D%0A%0A%20%20%20%20%20%20%20%20%23%20Use%20the%20mask%20to%20fill%20attention%20scores%0A%20%20%20%20%20%20%20%20attn_scores.masked_fill_(mask_bool%2C%20-torch.inf)%0A%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20torch.softmax(attn_scores%20%2F%20keys.shape%5B-1%5D%20**%200.5%2C%20dim%3D-1)%0A%20%20%20%20%20%20%20%20attn_weights%20%3D%20self.dropout(attn_weights)%0A%0A%20%20%20%20%20%20%20%20%23%20Shape%3A%20(b%2C%20num_tokens%2C%20num_heads%2C%20head_dim)%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20(attn_weights%20%40%20values).transpose(1%2C%202)%0A%0A%20%20%20%20%20%20%20%20%23%20Combine%20heads%2C%20where%20self.d_out%20%3D%20self.num_heads%20*%20self.head_dim%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20context_vec.contiguous().view(b%2C%20num_tokens%2C%20self.d_out)%0A%20%20%20%20%20%20%20%20context_vec%20%3D%20self.out_proj(context_vec)%20%20%23%20optional%20projection%0A%0A%20%20%20%20%20%20%20%20return%20context_vec%0A%0A%0A%40app.cell%0Adef%20_(batch)%3A%0A%20%20%20%20torch.manual_seed(123)%0A%0A%20%20%20%20batch_size%2C%20context_length%2C%20_d_in%20%3D%20batch.shape%0A%20%20%20%20_d_out%20%3D%202%0A%20%20%20%20mha%20%3D%20MultiHeadAttention(_d_in%2C%20_d_out%2C%20context_length%2C%200.0%2C%20num_heads%3D2)%0A%0A%20%20%20%20context_vecs%20%3D%20mha(batch)%0A%0A%20%20%20%20print(context_vecs)%0A%20%20%20%20print(%22context_vecs.shape%3A%22%2C%20context_vecs.shape)%0A%20%20%20%20return%20(context_length%2C)%0A%0A%0A%40app.cell%0Adef%20_()%3A%0A%20%20%20%20%23%20(b%2C%20num_heads%2C%20num_tokens%2C%20head_dim)%20%3D%20(1%2C%202%2C%203%2C%204)%0A%20%20%20%20a%20%3D%20torch.tensor(%0A%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.2745%2C%200.6584%2C%200.2775%2C%200.8573%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.8993%2C%200.0390%2C%200.9268%2C%200.7388%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.7179%2C%200.7058%2C%200.9156%2C%200.4340%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.0772%2C%200.3565%2C%200.1479%2C%200.5331%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.4066%2C%200.2318%2C%200.4545%2C%200.9737%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5B0.4606%2C%200.5159%2C%200.4220%2C%200.5786%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5D%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20)%0A%0A%20%20%20%20print(a%20%40%20a.transpose(2%2C%203))%0A%20%20%20%20return%20(a%2C)%0A%0A%0A%40app.cell%0Adef%20_(a)%3A%0A%20%20%20%20_first_head%20%3D%20a%5B0%2C%200%2C%20%3A%2C%20%3A%5D%0A%20%20%20%20first_res%20%3D%20_first_head%20%40%20_first_head.T%0A%20%20%20%20print(%22First%20head%3A%5Cn%22%2C%20first_res)%0A%0A%20%20%20%20_second_head%20%3D%20a%5B0%2C%201%2C%20%3A%2C%20%3A%5D%0A%20%20%20%20_second_res%20%3D%20_second_head%20%40%20_second_head.T%0A%20%20%20%20print(%22%5CnSecond%20head%3A%5Cn%22%2C%20_second_res)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
eed7a2fa50d689556c85e34a695f8787