From 2e4ad7391383ab0c21f7febaacd38ef5c00a543e Mon Sep 17 00:00:00 2001
From: "L.H.Byrne" <lhb1g20@srv02441.soton.ac.uk>
Date: Wed, 8 Mar 2023 14:29:52 +0000
Subject: [PATCH] working Hetero-GAT

---
 embeddings/__init__.py                        |   0
 .../ModuleEmbeddings.cpython-39.pyc           | Bin 0 -> 6468 bytes
 .../NextTagEmbedding.cpython-39.pyc           | Bin 0 -> 7159 bytes
 .../__pycache__/Visualize.cpython-39.pyc      | Bin 0 -> 6004 bytes
 .../__pycache__/custom_logger.cpython-38.pyc  | Bin 0 -> 688 bytes
 .../__pycache__/custom_logger.cpython-39.pyc  | Bin 0 -> 664 bytes
 embeddings/__pycache__/dataset.cpython-39.pyc | Bin 0 -> 8383 bytes
 .../dataset_in_memory.cpython-39.pyc          | Bin 0 -> 3325 bytes
 .../post_embedding_builder.cpython-39.pyc     | Bin 0 -> 10916 bytes
 .../static_graph_construction.cpython-39.pyc  | Bin 0 -> 7007 bytes
 .../__pycache__/unixcoder.cpython-39.pyc      | Bin 0 -> 8579 bytes
 embeddings/custom_logger.py                   |  16 +
 embeddings/dataset_in_memory.py               |  36 ++-
 embeddings/hetero_GAT.py                      | 179 ++++++-----
 embeddings/hetero_GAT.py.save                 | 297 ++++++++++++++++++
 15 files changed, 451 insertions(+), 77 deletions(-)
 create mode 100644 embeddings/__init__.py
 create mode 100644 embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/Visualize.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/custom_logger.cpython-38.pyc
 create mode 100644 embeddings/__pycache__/custom_logger.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/dataset.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/post_embedding_builder.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/static_graph_construction.cpython-39.pyc
 create mode 100644 embeddings/__pycache__/unixcoder.cpython-39.pyc
 create mode 100644 embeddings/custom_logger.py
 create mode 100644 embeddings/hetero_GAT.py.save

diff --git a/embeddings/__init__.py b/embeddings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc b/embeddings/__pycache__/ModuleEmbeddings.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f517b2832d279eceff26690cf50123b46795f4a
GIT binary patch
literal 6468
zcmYe~<>g`k0givkauN&-k3no?%)-FH;K0DZQ2c{|fgyz<g)xUA3PLkRF@pI_QA}W(
zIf@xfvqZ6^Fr+Z$u!ca?u|=_i#W|ulQW#R0b2xLkqPQ4Aax6I<x!h6Qxja!kU_NUO
zZ!TXHADGRS!=Eb<B>-l#=LqHsMhSx1963U{!coGE45^G+BB_j7qA8rcjEoHK3@KbG
z+${_#+|5i;V(ttnJSn^_3@N;++*#ty%uy1lk|})IOhsQ(#ZvfF1X2WhnWCgpr4~r1
z2rXn}WJr}sl}(k(l51vcM)2ezycEV@22J6YpfK^%WW2?wrO9}U*C{ozv?MdFG$%N}
zv_O;T7E@l{E!O;klFVF?5;u^N;L_aO#G=aZqRf)iB2C6yEG300AX(mk{NfVV+@#c$
zl+3(zO~zZSIjLo-ImsXkkTEMP;CmPt7*ZKPfgQz^BAv>d%96q~hbe_Qg{6fxiZzub
zg>?>N3R?<$3riGRicBg?3g;X~aA2@Tv8OVp@XTRM;Z5OdVTs~MWlm*I;h)0<3V;@t
zD9%)t6rnkcDZ(itEi6%7sobe7U=^Zsm{Y`3#9LUJ8KZboSyCkCFs4YRNVTv;@qz<G
z_Li7$eoARhD%=+#MTwbtsYQM(nQn2U<|f4#XI7=&;!Q~`NlZ#CPK_@}EXhb_M)Dje
zQGxvD3`$r<3=9k<3=0@*7_ykMm=`kEGMBI{U@c+EVq3tzkfBH+g)xsQg=ryU4MRL@
z4f6ty6y_R+EY{f!DJ*lDL6KL(l)_NTSd>u0xqzz%qK+G^j<tqqHp5&duo{qVo)X?>
zh8l)=z7&RF22D1<TP(??DT%k3lM70#1OkfklT%YcQKC?uS(2ep#iF36pjXA~oS&DM
znp~2aqEM2rpvixWJtsdsJu@%;7H4{DiBEoddTP-v&iMG`oW$bd`1o5K@$q?yxvBB-
zw^&Q^i;^>Lv8AMzWhSQ<@iQ<m+~UbBj!!I0%*;tl%1OP&l9`v5e~Z1iFekGl)%X^B
zGQ_G|Oesl4ybKHsMVuf#+@JtWD$UGENiEW3xy4ePnv({R<^xHyf!vc?w34Ano`Hek
zmydo%er~FMPDYYpx{-l?Ze>b-a%paAUP-ZjdQoCQhDmWrVsdtVS!z*QPJX$5Dl8oq
z>!U?!v0gzX$TyjJnI-Y@;-H8U0L369AEOkb2qOn052FAh8)KC`nm%Y!(2I|U=+tDq
zC6EhFE%613nMK9%1=%^tOdyk>7?dtJ7#JA9X(0rh7HSw%7@L{;Wons97#A=><CHm^
zA%%g3frX)&nUNt+z>uL>$AEzm3?msB8A@16SW}o%m|GZ1*s_>i7(kh~85Ha6;MivI
zyTw$Ll*|aV6hwh61|`H`kl&;j7#PwSY8YaLY8h)7T^M3HYME-7CNdZD6tV<E!WJCB
zn#{MD^bBq>R^DQ)01H3}1yEFQ+2mvvmn7%s7TBpUFfe=uSyE++BYe_|@^j<UGILUk
z^=xwTlM{1_?eq}(H92mv<QJso-C`@qOwI<Sm7M&<lv{iypp2H8mktTbTPz?8Z!s6A
zmfT{_O)a^_0}bi2{N%(WP39t5P%2Uaxs5Rm;!sc`DT25b;Q~ca2(W>i%gDp1#aN|C
z(9JIwv@tShGT!1WPAtQczd0Eg7(j6f3gY5D;5c2tn8H}Y2+EHg3|UMInQED9m{XV}
z85S_NGo&$sOEZpzjI}JFc&uSrz_t(^gB=X<>@_S4m{V9&SXvksGS;%zu-33H;3(l-
z$e6;G%`C}K%bLzm%T~kc!Vt?-%U;7)!*0${%T}0D!`8tN&jl?Tc9d|}Fr~0JGxf{W
za+L5a;O$_@;;P|T$i&Eik*d}qrz&vT;qa>p4R-Z$bq-N*%1^0Oa0~MHRRERZ3gI5E
zL9PlKA&Ke53O=6Rt_td^1(hWk`FX19nhK77E(#i84N$4f+=Bd~5|EUFkEgF^h{7!v
zkTgo71SL>VObCObL<UslGNv%KFmy0vF)Uy#VX9$BVQywx$OuZ`!3>%#ekhtj*$Nco
z;1n$cOVM1Q6rIH|kqMrXS2EsWgQOHqrXo;MEm8v|6mZ%A6QBfAq|U&=pa~M!U|?XV
za>17-Qj(Apabj*#N+KvVV<hWaEIAdqIhxExAfMe5O;0U}hXhJ-d|75<d{Jt8YDE#q
z2cY^gqa-(HB_kva!(DZY!zKq}KEi1Rq&f{H&_E#$s*AvZ1}fr87{H|tLki<Ure;v(
z&*b+K6tbGkMY^D@#0jd<OY-w`if{2G=jSG6<|US7=I0e_GTve?$}dODXGJy)3=BH(
zU}TDpy~R|Xq{(-SsUYPRUr}meN_=r)PJCf$YEfm8G{}6`#DapHO0Zo;dLRMj;?mqA
z0}zWXvACqNAoUh=aY<1T*idjNXo2KdQc80RG&#ZHb&InozdSxCGdHuO$b^A`AxfYi
zzqlkmDL<t$zBsigGqpI150U_iq5L9ykTO{L0WPU;F=iGSf*cMGd}|O3lrC<ugA0z7
zBv6_umH?L;B1~e8LQH&&Ld-0T9E@CyY)qi~p-K*a#Aq_!;sO_F;PwJa0|Ar<KxH$y
zQs99WAT^9NObeNU88jLFG?{L(6(#1S<mcXED^AQU$Vr8kcDI;P(u#aQZUws<On_a@
z0Z|s8X9>z#p!j6KU3tOOYBGV%y2X}Ske{4UjN(L4?F%a7z)lolgtTE6GSo7sGt{z_
zFqME>DJ-BGD2p|Pv6pE9TMa`NJGd3YQNo$URl|_Qv4DFa1E|jGWvXQ@;i+NBVgWTj
zQ&>`1TNp}sA$(9y2enyh7#8r?Fk}hTFfL@OWh)V^VQL0dQ?qK>Q`l-bN`z|I!2G%x
zrdrNgt`gxaks5|9z8X$Rh7#c{(G+%s3u?K+c8e`!U}UIaSin)ky^yh%r-r$PXMy-a
zkZOodE=--xOtri<ybC12I$0M;f@zK#-UU(%8O#}KISN~9SQbdvaDZA|g(c9kQ<Lcy
z7q~tG$MP*+c#|VOB{Mhu&&rRNe3Js~H5qR)<rJZ0OjsHKXG}2$28K*f{md7`RLcme
zn`;;+BT6)Irqg7G6lPwa?7$64-{965SQbKn(*l=GN@g;oqIL%*6>tHEs}upJP^97w
zuC)kMa)8?JDWE*5DN_VWu(#O#e0+TJi;HjZqb2@Z%sHuf;A{+Tz1?Ce$u9!=>K1Em
zeoAUi5vUBh#T@MJa*L}Vu_!S&wIsEu_!etHQD$DrEzZ=u(%jUd#FA8a9bb|kpIK29
z0CGR5B)G*|l$e*ES_DdxMWUb(U@J+@E6y*v#gdZ`YOYqL7UjpM7bT|LVo%FSEGbFN
zyTy@|n4DdnSd?;$rMM)u;1)|}Nowvbw#0&h)Vvf;K5$V2&HzP$AOm@G@{5b((@OJ_
zLFL^oHV`+p_!cL)-;-IDT6BvYl5vV}v4S(qEv}OMlEj>NkisaCH{%l_R8ka2ZhlI9
zVti)GEzX?$^!S3J{G?)#ON)yk@r1~@tl$U$m7>L<s)~V$gOP<vj!A@xi%E=$kBNto
zg;9V}j**Lzi&2A-gOP)gkFiPyE$m^9Qm|2)jJG(7O7p-e1EVDhiUn|vHUbwqur^CG
zILT!}`zTo~pdJV~=Q1-g)G&csD~t>&EWr$#tW`x^3Qz#<`YFKEJE#K@0*;cT{KTRZ
zg@U5|tkmR^e1rx_gh6C=b#wC5Q!<Nem5NI9iuH?1^Bjx}jEoEo4UBXR&5aFAER^6D
zRk4EAz{*>N<b2)y<mA$#qSU<PRD=UGnQn39rj{gv`enD6OA^!lG}(*T!7-bdk`fPh
zXi*GEj1$hk#a5o0nVwMsPQ15RlXLQmVa07x2}mhhIk=OH9&>yp`SD0z3I`Rupt78S
zm5&h=ZA|~zn5$HAMHPx(P3Bu%;A{f!OBMxzQaz|#Kq-+yaR*BN;P^`e#~-N5Xa=`6
zL9M4O<{E|)rUfi1j0+i4n0moI!xGjSh6QY(_UA&T6jpGrgQbLh0doz@LdF{Acn*yA
zVFYsfu!b4jlPm(Y8#Ebj@#H2J#DiM%pbSz4a&c)+s%~m-Qk4iOy@RuCJfvU(i54k?
znpm7gk)S~22et6yi%SwqQsY5o3^)@PX@h(UDkh3RHT^C2qSS(%#N^Z>P?cYl2-414
zR0~oA4-#<fAywRnoWupT0UV4;ptt}J{4jAbN->KtYBAw#^&@m?GJ@^G$TaXk2RQ=P
zm0!TTkO5R#fx2d(Y$VQ*!US%BGnKGHDnn3hj->(af*io$N`%=j3KW*{knk&V1bLeS
z96q;LKphaUmv3<)g&iocAYtSI3LJ1iLAyiX&_c9pkb(zf6exC5b8<n!18RmdFbXjV
z;O-D1bb#9=MaB#a41St?MIxXG<%$OxR+<AE4B&~6FDy;WfesBYXXcd@fomR6&#MSj
z)D(Gx+zBqti$Sa!kO@2>A7v)zrj}&nrxZ1U#9KfF*v$w6RHqejGcYiKiv40xyPtze
zh>MR^K$wqNK$?%6iBW_LERQk91gc^|4Y*>E6JTRZ;Qkn>*TE6R49XWR3{fnptjMEG
z?F=joQS8ACnjGMyi;@XJMuI9r5M2x^wL#748ir=31x%pkLJ9K%7Eq>K$OIbBDPdl~
zR>K5pXxA{rvx9h`lmeFH0LMI&-^&J2Y;)dXEiNrcEdtwD1P(QjCAT=?Lr_H(pb%s8
z$;?YlEV{*>1EvjO3?pds5K(fhWGn*pu^?p@I0QkamVqK3l$;qD`M|JB0<C<44{V_1
z0#MNd!r-C@<m(!S1)$W=2<rAOWUOT>VOqdk!?1t_G?)Pz;;La-z*++t>|)Ad19gQ{
zn0uLOS!zI)2#cR4(@Rk6;^qJU|Nj?F0EGbKE#{QWTupX}KUvB$Q_G7$DfAXgQEE<U
zQ7<T6^nrqa8&sbb=ckqACRS*&g6r^GY?*llr6t9;*uZ6a@h#^3(vqS=kQUTbbc;PL
zzX;SI042X-aQVO}z$n1T1*%lV(L(~Nw#bZufdM5nK&@*~^ngRdiGhKkh9QdqG)Px0
z<i(J}*v#0(SOcm`8Oj)o3?QAYEaohhELPBfz(U4gka}itCCgC5D9IoR8m?pZt5SeC
z0921==B1=oDC8y<6lCV5E9B*uC={obC_+Mt4O}iliU(FutEU(oW<`~t&|xjkNGwRz
zWCDjvkvWzi5CQc{A!Rh8N^C}?3n4}>MiIs;6^LWt=}!-vHcbw2QBc$hvZoD1fD&2J
zOb`p4|L1_*#0wohOis<oiI0b*NN~LLgVcf3Ca8u0Rkg*SDCc0}VdG#2X9ds*U2bA#
zUVQvZIZ!WMSWizMWrR;JC8-Eh=T<4|>FI+ie*L1va($@xz}>uf!`#dqy@Ko<O@$&m
zkcEOE0$jiNf>@FuC$g1Pf<pQhTR~!8N@8&lC{iII%$!tgQUuEE;N~y56&oc09{SXS
zwG<O`ASqQGB3xRMnNzHX)Gz?ILBI{yTVmjTDmaybhGOHP<DSrw(4@rT%w*^MytK^p
zTP&V_ZvNm>?G_s(D!>5(32%Ln!$D<QQ9Oth1tLI=WN-@xDS*MD3<}^|95#?Ju>)nh
mVvz4aO-U$b0@0xHP98=cMjj>}CO&4qcp)(kMjj?m&Hw=8hEQDq

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc b/embeddings/__pycache__/NextTagEmbedding.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5bc5968612d8fbb03504a3b43e0a545dab411d0d
GIT binary patch
literal 7159
zcmYe~<>g`k0givksnQG#k3no?%*MdL;K0DZQ2c>`fgyz<g)xUA3PLkRF@pI_QA}W(
zIf@xfvqZ6^Fr+Z$u!ca?u|=_i#W|ulQW#R0b2xH2qd0T9qPQ4A@+>*rxja!kU^Z(G
zZ!TXHADGRS!=K9^#m~r)%9tgP%9tgX!rsfs$l%V9!jZz+!jQt*%oHW$&XB^D!rj7<
z!kx;QCEUy$C6X$d!jsKZ^d?m(g*SySg+HA!MIc46mnBLpRcwLyLIy^LREbo{RIx0n
zX2xblI8PeN6I&pYB9tP$kcp8Yg)x{xQ{*KmEc`SXZ!v0VGTmaz%e%#zUr>^n3leeD
zWW2>2T$-DkSX3Ealv$Emq{(=TrKB(gB+HtUT9%rV4APE_;la6sfq@~F0ThH$OeylI
z%&E+&EGbNLm{XWjSXx-4SX0?jStJ=!nN!(P*(DiLSm!XOu%)oKutagBvZS)4aLi#!
z;Y{IbVTs~Q;ZEUcVTj^NWr65P;hn<-@_GwP6n83fiohJk6u}gs7M3U;kS!^~bC^;@
zQbb!=qIgqTQpDykriiCVw6H|+rShk;fK;S1r%29WNs&sCZeeL=j1ov?Ns*bum?E1Z
z*TNDd2o4s7TjGAH6(u2w>8`m+sVOO$dFdfViJ5t+MShyhxA>Ct^GZ@HO5$@;^U_N)
zZgHgMCdC(LR;AwJO-U?COiC<HjW0+n$w+2J3K>wqfcVa!v{uK!z)-@lfU$-liz$nF
zA!99T3CjZ35|%8s1?&qMiWE{9^O#bY7Bbc_#Ix40F5pOEu3^Yxoz0NKGM9NFBO^l%
zOA13NV^Km0=K`)87O*;QusYTnme~w*nZRlo7#V6n7-R}h4O0z6JZ}wi4MRL%3PUi1
zCY#?amgLfu#9PeC1*KI20Y&-Asi`23DwJoIWGGazDCjBZRq;CK=jEj)m!zgBl;kUD
z^50_5$xlzu%uBz;nVwqWlb@cRT6BvuK0Y}ovA8%s{uW1kd|qO1YJB`H){^|9<cwQv
zDXC?d$*D!W3=9mncruIQ6U!1aa}twsQg5+j=B4G|VlOVt$t+1VzQvvlvFa96N>ULh
z&lYim^zf!uB<GZ-q^87|B&HW@vfW}SPR&U}C@7L*U|=ZX0V!nzxh=J5B}0)c0|Ub^
zAN`E{+*JLXj3mQ!BLn^1%9Q-%(%jU%l4AYzqQrs>lj4%Z<m~*i)S|SU{Br$NSWFk|
zqsP5oK_$p%nR%Hd@$n*{2owXwEF&MI6r%_u2O|%o03!<{3nLpN8)KCMx^8I7(Tk6V
z7@)~~iv{GDTii(L=oVX1YDsBPUNSStWGIHE8IV|U8v_GFJ3|^{3S$aW3r7v33qvzw
zEprJ&4RbSN3UfANky1KCFhe>+ElUlv3qx#JCu2Kf8e<9zST!?PHERh&4Qn%F4O2El
zQA7>v0>%_pP)aFufQz#f71gk&Fr=`xGD$MjvKB_vuq<G#VeMc{1E;A%ts0hK22FOq
zWJahHK@<}M0|N^K1A{OqWI)l8&QQY;D^$x^!`Q))#Zbmr#8SgJk*Sa+7!os@jNnLH
z$#jcJ&)^ng$x6l|X;37BVy#FH6q8&wIhn;J$@#ejb}9@E44*-^RGHxl*rd|ToRs+F
zr2KL{o1FaQ#GGO~J%j--K~BZ(UI_*U2C#btAnwg(MsqDVF%>B@FfeE`-eN0BEJ{x;
zfrx>eTBOdvz@P(iy#~nP%mI!rFaQ7l|6h~o7IRK&-YvGof`Zh%6iqgW7ZpG%6&V;9
zqIlCXb4pT+py~P+Ye8aWQ873ZfQlBd1)59{O<EvLAj@uXA-t*t3T%+$8JL8aco<oj
zSeQ5%K`b6d9!5S!0j4S?!r`UKbc?SbzqkaPx4>oAEzXj}^!T#;<isSDq{zv@zyK<~
zz)7(JTzq9Q)-Zt*V=q%Ja|u%ka}9F}vm^tkqDf&%XB21XV2Eb{7jldX8EaW;SW;Le
z85XdDlQ5{H=KvLZHK5{+t%Q9cV+uz$vm`?;TRKB6dkvclLo81%M-6)ohdDzndtpKi
zdj~^22e{zk^aIsP{Mo6g1@S4VX^EvdCGmNQQS2d!>3)g1sYTkLcwh-hOfN=>bWkb)
zc>tV9L3y}?VF5!2xC~whN~TN;nSvQsGWuyU7a4*~;44ThDozCjOgtzpf(x9LjJKGJ
z^2>|VKoNpStsFKviMdHBiFVc?aZoy_GQk!0X+`<D@oAYksl~`?6{7tYQ*^8*=Pjm!
zlv^A{sfj7^$;D;2SU|SlVlGZCDKZARi#4&JAgA&cTVioZWkKpK=Hil~Tb%GX)Z~OD
z86%KOK<cB|K=GN9Rs^bdZgGP+AOqu5(ryXCil4+hP~i$n*O1tUBt^zda3X|w7Lxi9
zjyD6PJ|R%HWaML%VB%sFV-#XyVdP-sV#J;95ZX~<2Bnm%0w=8nj49AK0cFY*CP{_`
z%pD9_EDITHnMzno*lL(S^<)Y&D4Bs094K+rFfHH&CoYy0))r9WVy<DXVP3#h!o83&
zg$<mznBk=(Pc3T=OAV_zLoG{TQVmN7Lp%>Sak2YV`31ZBxH^X@fKrfxTadr6LO_0T
zi9)!CYmh6L8B$r0>X~9|sNm@50#W7T>FuhZu3Au8l98XMs;=PU>FXJyU}yjWMW&#r
z03}KkPlJ*pyv*cdU;rgWh8l(rhAc)<LM#FWuqG2YDT0$;kp;*Y)Pz@L#lXPe2@(gV
zIYYc@E+q*w!4-iLtR`=f4Jcvpfs<TuVNQHuX=+hrkt9fvIXAJO$Q)#a9f+_8ixrpV
zg41FVsNG!T2$FCD5uhXnE9gM%A{UUTD~NCd5$+%YoNd7V0H-N-aH2>_0;w;CB`5(V
z2__*#f|4iTF-@jh9I%941S%JkK{XO6ae^?24Jx(ZHq?NtP*4T8fFXr>A*dP5;#Z^z
zavZ22uVS*Z(PX^ET3nEmSpqA9eL=Edr-KQQ)4{<K2+9W_8yFa?)NnZywItDGyv0=l
zYD|Kgj3|XH$eEx-0d^{=NdqZ9Yrtg_G)dfID@x2u$<MvTR-Bkykdvy(49@Ylm{QV;
z{4rg5ivyx8J}(5xsW{sxFvXfoV8d>)r55BTXB4Bj5|mmw7#J8pt}T{hgtTNAGSq@|
z7pNit<<w?S{$=f8$YM)j>}6iS-ocQ?0V-$XIZL>*xH}lKI2Z6NWJqC3VeVzBWlaI)
zS%wl`h)z&jrGp`!uY+L$e+NUBKn>$UrdqZV!5XG!Mi++IRkiFXY_%LELN#n)eqRhz
zEoUuPiEx$(IRCS!aI|pLaDq%J5zZ1#;p}CC*<TB>Sq#)qDiN>Y2KA6S7#48Wa4%%6
z<*8w=;aMQDkO5>;2SXNL4VNT{jbcJGQ!Q@|?*d7X2{o(>q`)*s4etW!g$(8lwH$?Y
zH7pBcYB<2{syuL+!UgIxaDl5-aFpNTg*P?hQ!;aJ@xaPqP!}To&&rRNe3Js~H5qR)
z<rJZm8K3|MC68cGOoCdVnV@zSUkp<%V=Yq+Qw`%}L>miSc55;hftpW6pp18mxg;?i
zED9k&N%$6*O-g1mq@9%jN;~+=5pb$SDn;OWi$J9ric>*Z0_0S1Z3}WLxQWGA3rZCX
zHH?xBD9u4gwNsKGpIHHJXhG^6L{0;>vC>I*AE=@Ow;@5<IR#W*YRVRY0`V5RpO24E
zesS?F0rWguqzkeKRC9v6AK<e07E?)n5va*<i#0btB{ioA<g;7M!R{`%xC#=B5_3~a
zQj3ajai->#=B5@UmZTO%gA4&h6QtT@ElSKwPc32xiGo7p7F$VbUU7cWEtZ^oP*<ud
zwJ1M6y(lr|7HdIKW?soH_OzVDl9JTCTO3J=$=T(JMJcyfic3-pZn0#Rq~;dsff}QH
z;Cv6xrbVE<eTz3IzqmL)tu!yWBr`wn78{70T6~KW+@H^^N-es@4$0odw^+g1I*O|#
zza%jy9;C3y0OUAOOEU`Iy@fiGGbcYizMv>SsTkzm;$lepK@{Pv;B*lVN|d0~#=ykE
z$ik$=B*G-d#0P4bG72!NF>-<11{#bUj2w)7j8(Ge;STF0fz8rnyv0#eng{MmqLk&J
z$OARvzy&0zjZ?yq!dSxy>hd;&^FS7~OP$36>M(<g7iLC=8m0wopp2RV>b$U46>%v*
z0eDnH0iGQ}ecTXm%p~O}7NsZ@6y;~7CYR(RG(ds`BCD&Llb@cFS!AnJRGL?;UsRgs
zU}RuqWN2t$q-$tyY+z!c1h=S)6|AO82;^_wki>L^rJ&j?H?<@YG;~lD$-uzir^#Lf
zY7<5ACZ?ps!yQ)y%AvP7;rv@{<*AwJ8L(zya!!6RtTHGn1=Z7R<=`<JNOZw7(=EP|
z{CFg96oTRnRM#-D@-gx-vM@6JV`Ii$m!WFbWVyu!&Jy6hLs2Lw4TXUSQ2#v{lubZ2
zFbIP@1gbQ^F_!^u<kc|NFx4<OGlTjD3m8k7vY2ZaN|+X~q%baIOkwH;cQ{K}YZw-=
zfyRs$GNrIKGxf{VvX-zfV6I_Z$XLS?&k@d$!ob47!qCjj$dD&s$WR<*z`zKGk)WXu
zP<*rb6@h9jO~zY1xrqhwpbjD^e^jxTB&O@8<|b7Mg4zz?!Uxpa2MOI`DoWDiEQ$k#
zB7aVPVoH2*Nn%NAJgC4c0%u?rkSL@~V975?&AY{3lv<FJn4EfxtspZwI|ot(XtEX6
zfI<T`Tp$t2n1&uPTp%;!!NCdARSXJ41|}XxDP|EyEhgN}YJ@gTMzDPtSq2{PLEwOg
zj5{r0UdRBdSs58hSU{OboFRp&nW<k6)WTlCR>P3R4r+HWGGGL;3vv*HYdB`VC@w@G
z-eM|G(&T`KQ*mNhD#XKFNMYCr3YBJ1@Uo<o<`#g<OK3v|?1Q2fkOFX;1rw0U3S<@}
zeA+>Y7L=qI7=@Sw@N_v4S~Qvb{4|A%KuyeBT=5`dN^?NtR6OzVg{6r(&`~Sq%)F8!
zaAQ^;>^h&!;*wjeA*BU5sgSm~2gqSQAOhSf2AA<AAZ|5CJrBr_naR1SB^miCMXez5
zE)W6s1cCs0rwG)e2DzpfR6TMqiSS8@2=H?7i7@jq3rO*CGck&Bfz@CPv4RRgP?jnN
z6@aiIR&dJ)JTepo8DeFKVo7C9VV%PW9ad#)XJBE7Vh?7}<NzmZl-vo*sUQq8x){`|
zFJUNQtYK(oTEGM<H%piouz+&wLeL;0YYFoLwi@OdCJ}}dMz9DwNCcFGz&uDVoyqTI
z1E`kbyv15vT98@<wzCKns6`-4ZgIi~R*T9(A;;#EnU|Vabc;O)OdG-&Mw%>;dJd6?
zKsC)u#v(y@z=DGlWb_13vH&G^21Y(cK1LzNDoIdzi#&`AAB987B%nG4gu$f~s5ex@
zumF@f7(pFTaHo@L0doz*0+t%4g`m+45Sz6IG;qsY!j#2U!wjxzm_Q|2FoPzGpC%Kc
zKLKWGGTvfN$;{PchxnGIEHkyd2xQ|emZH?0(xRy#&rAaa1UIO*EzVCX$xW=#WCi>E
z7F%XsL1{^G6dSnkFTTZ`Us_UB1k!?<rf#vP<rjfk3m}sq<p-kxqW~ipsLGYV2nwiT
zO-69FgAyR1_8TakzySj4i`IZDR)#W$Vj(Yv6vk%8CNLi~3SF$?1u8ipB8*u~H4Itc
z7M>(%l$qJDN(@xn<>i+sWacS=8^NVHiA9x)Rk9%Ql6-~Cyp+@mh1|q~g3P>h1(3qx
z)DlHaW^kbZ4z425coVEA2C5;^OHRS`RPYcMqOzO|N(s;c0+bFg#=`V4v}kgG3x}d^
zP`LJh2vCwLng?QmON9BL=;4Kqbtk9h<iy7#Ib=FW9XKU|s*OIF{Txg@Y#i+1)DId$
z&P~kBi;sT^3f`9>#kbfXg9cT?dV2aO<I#F4Nt!I+<|m{f3K~}h4|8X}U}j)|Br4fk
zoS<2mlKlLf;v#;K13+bJQ2>Y~4stD9NhK&$Z?P35=A|SSgFE`*f*G6%!R^>60q{V*
z9<1G%m;+9iQQ{Ec(vr-aVm+kx0=Q`cZpz-`PD(7!Om@!COUq2Z#p3Db<_|8gZn631
zr<CTTf;|lm(<tVW)RH1_sR^q3A#H3>(t->lfXh!rN(QsQfx}^w3pU9P)W0YObxuId
eM<`|j(E^O1;Z7bVK4!ieAyE!SA#n4Oj}rh9^3o&#

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/Visualize.cpython-39.pyc b/embeddings/__pycache__/Visualize.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74ac8f2a2d5b9996aef84b0c816ad0946d3c1e4d
GIT binary patch
literal 6004
zcmYe~<>g`k0givk-3km0k3no?%*?>R;K0DZP@KoWz>va_!kEL5%NWJT2x2qkFhwz^
zFr+Z$u;j8vv4O>ybJ%k^qBwFnqd0T9qPTLoqquW<qIhz7qj(t^QaPgdQn^$4Qu)%D
zBpIUkQ+ZMaQU$?mA$NupmK4?&h7{IjrYPYQreFq5wwEB6Xfod7PDxEmEX^rN$xJTs
z(`38F?3h=1i^V6ixa1a#OJ;J(E!L3If}GS_tf6_C`FXcEoD*|$5|eULlR^5Cu@KBR
zEsP8dsSHt!DGX6eDMG2tSuClnscfn2sT`@Cshl7eL0GArY0Q!gAU+7EvV&D|f!W+(
z5=A7H3#1xV6+tnGZQLm8U}9j~q2{rtazR`MwwF5<!Dp6afU3c!59(T|`J7;tJP?yo
zxgq|8ilM1YVVc9A!kog=!WzY!%9qNX%AX~W%AW;}FG+?J);WwRY$@z5EKx$Ja$p&b
z6wVfwDB%?06s{ER7RD%%6p>W^EK#tWK#E`sYm`_DV=#lJ=q(}lqQruXu*~Ap#GK5k
z#FEVXJU>nTTa50vm<#fYZ*k`3r=-T0q*j#N5`r`13-XIIK`L*FAjH!0^Gf2A^K<fx
zZi%9arX}WP=2YGiMiVa1tV)Gi0A_N5<sn80g4s#=MJcI8@#UE*B^kF+gdnPTz^d{K
z5|cAaDsOS6rlhAr6+_riRbX+bs$>DA6bWMUGcYiKvRUyp1_p)_hIWQD#uUaBrWTGG
zMi+)=#%9JMxfJFahIqz2C8#V5imXivD_EALgsF+Kgt>$T6fG$nDV#M7S&T(tDO`C>
zDcmVMDZD9sDg0outQ3JfrWD>3!4&=!p<dn^hIrN*#u|orwho3GhIsZGkZl|_EHw=A
zoHZb`xoX&I7~;8W*lQT#d1^Rn7~*+rIBOW<`D(ao7~=VBxN8{V1!{O|7~%zMK&A=R
z@YOKH3)k@1FvN?bFa$Gb3i~B9B6*vMfq{XIfq}soRI(^BFfi0G#7fmNmN0ZMEMTl*
zOlMfg)XbQ|kj)gvz{pU@1WqYTFG0HeRx;gU&PmO?#hjj6vXbc*lb*pX#<G=+McfPw
z3@aIax$0-+=celCWF#4;8yV>5R;J`9m*%GCl@#lzgVLEvaY<rwc79oEQCd!ZxqfPH
zQff*{W?p)+J}d>M>J?Pp;<QOm%}cE)D6&&zU|{$Ra(a~>W<u4AkI&4@EQycTv&qR%
zPRuE`(}Qbtqk@k?$-=Klih+TlNESpWfC!L_ia>!@q|Cs;0K!Eo3=9k*RlK>GDJeOr
z3dyN?C8<SK?8*jUaEsNkC^IqVB_9I=!%I+U{}NOTR<SFmC4*p*Bm)CO6}z%A7^Kgd
z2?8(eU$Q`iEet?VledTmWH&E}02x#SO0`9x<Xa>N5(61oBn)DSfCy0#AqFDEL4*VY
z11PIAC+6hbV$Cd$H%__56_%J&n(A6qlwWj<1yr!#VlFNzx+MTD>f?)3OG-fTQUr1_
zICd#_fxs=6;?$fpDjHb?jtp?LfMVnp2P{E=!l>8{lx~E@7&#cl7&(|&n3)(6kO$5N
zvp#bOvN7U@RnmA8pC;oiwxZOM(xN<+BnOHsP~bU(GMN)2sB9@=Y-X%wtYOGvTEM)J
z!G$5#sFtaQVF61DV-3?nrdnn&n+eQjsbQ>Pu3<@IVr7tIs9{QHG-qIDuwf_^N@u8L
ztzlfi3fBKBhAD=rmbsRtmMxtD1eqC_7;4#TIchm;*jyN5OJbO6xoW{BM;2QRmn1_9
zvm`?eLl#>NH-w$Skj`w*P|FS~RiG+ZkW{cpGC);u)G%bR7ZriZ9jH21Bz3HCb)1Do
zB^)*EH5{N)u!IBThZK%ph%Y%8GNy3#veoi{O0_KZ35-P=C7fAYX$&dcHJq7@H9Sig
z`xx{1N;tE)(-<-tAbgO|Ygr3}YFL9AG<p1rzzG_hv_U1>OArChLbq6RGV@Z4Z?WX%
z=cQ^g-C~85{jA`+qX?V>*lsaa++wV}#gdz!l6s3X15^xU=B4G|VgYHr#SX0@ZZTGA
z@)UteJCH<i2{^spVg*$q#kY9k(^E@ex#JdVW=U#pF(iBFfO4810|Ub?ru6(<oWaS7
zB_*jv#woYhp-fHQB14cgTY6D`X+iNV7LbWW<{)hr3=9lWY$ZjB$*ILfjvx_B5Mc!(
z>_LP)NF!tEEw+NhqQu-{NV-=9$$-<n0N7O^^Wq_fc!3ff$T18|987$SJd8Yy63i@&
zB8&ozJd8?=N{nobDolvnz{SY%i;dHSsY)4JItTd}$?RlM)d9*{APg#&K;;fNmCG?Q
zfKy>Q12`SBfl^`(Ll(mVMsR9%1Eoq(>aXFHWJqC@WB|u48>n8(VgidXLB-f>7_yj)
zq`;{XBEt-o;V6_UVX0xOVNYSHVJs?2Va;PoVH0OaVJ~5+VXk2UrK|<4DI5zKYZw-=
zrf@=OE-1~tkg<d%g{OuEQXBWOqzLqKfK>{DWWX#Ts9IqtEwYd?MYI=`H(67}pkm?+
z8B-*BIcvE<?wr6_#8twU#h%8HB3Z+c$ymb$O~-6m9N=^e=GSu9a0fGJO8FImldK6S
z?m$KVE%w~RqU_Y7;wn~ca5IbH77K_j0vGAv!W|U8MWAy1C8*TD#gdbnR&tBAC^J2y
zM3dzfXGun8a&}&7aq%tIlFX8v)F}3f#JuE;{GupMFuyo6B~_E-7H4rretB_nVovHU
zE^yv}uz8A7%TkMqQ^7oNYP`h)X^nw01*A0<#RX=Ab=~3sQ%Q+MkO~1*(jsajO=gG@
zY>*~jQ8>t(ppvi%R3&M079rBLHAu<^l&;wzAyDK8$^{%C&x6{MQ9K|Ptfg6G2T~5o
z3!0qJ3<+xcq369KCy+815a9|UKvjSy4<z4#10W|cDK)1k6eJe{vYQX)n|N?o6@`Ih
z5RL>{SOm&^MW7sYixr%^&~qO@IQM~^1Ic|^pv(tS&%h|a$j2zcD8k6bC<e`LeDI2g
ziID>e;TiD{%MTWQ4o0ZVUlvvlMvkX!aypDvO1Ls6!rEj|a|M(qK^Vklg|%!yftz@s
zd|1P{kg=1ogsB76u9fIu%;Rcjgtd1&K{W$2r~$Z;v6HEWDTSewp=d(~gieF%NdUKx
zLEREijlcq_hng9K89Eq)84BN(Fm*6BGZwYAGo>-6u%@uIps6Zh>tHHjZ)Pf5)y~k)
zl*W|8hN-55DVQP84r=d#4u%>gkoi!PFLW?L`7o1SbucYptYLzvTgVj50MZX`^|GV3
zdO`UPghAP<7}UJ#05y#n{lL+w$p}uGnvA#j(V9SDX>fz@7KcqvVs27OqTN(bDgm|n
zstj?(3$y`+)G~$YMQJ2}Y8jAIz-|I%@pMpIoTrwtgQ148nQ<ahAxkg=#0B8e9@Grg
zWW2?ZlA0EuR+^U#kpkBNTsAqG#U;u4xdnE!;m$KA!Fe#<)Nmj>)cXkc6{RpRFwBQJ
z48sm~1_lOjAjmL)>W~_SSdm)B4#ouxHPE0~$b=Lanv6vtb>LRHCQ}h;prI%gWHvaE
zK*0gFZXrkr6iu2OMFt=?NYgEjWa#(=sJRGg`o_nDOO&D*P=UaeS)7?yT#}fVoLU5`
zKsA}crAHJG)JO4oiMgpo9w5^|mT59WY)=FUfSmxg9aIcW1{DJupv(xWeHnR}AlY4l
ziH(tiQG$_+QHqHTD#pRc#VEj3B~Ku>LT<4|=D6oMmAV!AGB7ZpRK%bT7^pws4YG9w
zxGM|l#-=cVA#(~-3QI40Eqe(=4Lb;?u=X<7a+ENDnrSIaDI7J-MP@0Sc}yu>Da<L{
zDNHFmB@Ez7k}rinMWB}@9iopTMX;Bzma~R)0b>o@LPkc05~eKX5~eH`a7D)ns`yfb
zdKqiEYPc4#)-Wt!TgU)XRl~J_y@p`{M-5{Q6R1g8!@7_aA`b4HhCuwI$$pD7v9u(=
zC^0WRHHrgNU!@mirbKa6r55FbY9NRhn8jX}S)7>!8gwejOwKNj;sdD$F~J3Vu_p5^
z){4Z6%;H<Dl^`mLwF*R4i58_NC2AO%n(827%_?57uz?O3YGzJgWMIfV1R}~A85mxI
zg7p@Aa!!76YH`U+P^k*8!(W2(t0v1W#-v+}Iky;#ZZVc<@<s6$Bo?F=#V4hME0!pp
zf}H#kIFAG5J#g*HmXQjozl%UEzoKAJ#m5J(fIu~0JV-?p7sy9Bsp+YCDN&HND!ADS
zsw~qp^KNk`=H=y=fX8c!i*i7zFBg=2S&NfX^HM=wlAOfK{L&IlPH-x|#a3F7l30>j
z1gaWuvAJcYmlmboV)Fs3iQ-C3Nr6;d#YLd>6U7M^C`thJ9l61IpePn(4ZQ9wiUV=M
z6(*>%i{b@^Fsx}@TvQ426feXT@sMWgE#|b$bV!i^$~#4%ta6JtxhOTUBsD$_?51^~
z<O`~`893p!BBaLr$p#bo%f^O~<ATUCGBLty)=z9gT+9-Te2fx|RSH<L6pDqKyne_)
zlhIF;HKZt$fq?;%DBU#0LFEJ}*T=`-;);*Y%}*)KNsW)c#S<T2Selpvm0`-OxWx*#
ztO%4sZn1@>7L}w{6oCr1B2bI@7PDVzZjnAH%s4<X#p0Tho?2uC;)7}caLg2ed|m`@
zxq#vU9QcqXO$|t0J%|8@9D)FaVG$<-0|Tg)RUE{?z`()Ez=#8WU}9onRACeXw@E<l
z5)nofFq;WNe&Df@Q;^c+l7;L4#l*q#18lyg$Sv;V{G6QB<dV$%ykbZPi6bwyq&&YU
zyP~KW6x|}A)SgqR2OjH*&rixqO)e>p;zi+tJ$j3+q_QA0FTDuVPALNW9~5fffPw@W
zD2$3)K^|@gg%)#WUdb)iw4D6JlA<<{D9Fo5NeLVvps={bVFQUFJCM=EpynwLBL^c7
MqXH8Ps2{`(081%^1ONa4

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/custom_logger.cpython-38.pyc b/embeddings/__pycache__/custom_logger.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a06a90bb35dfdd823d5608f747aeabf6a251fefe
GIT binary patch
literal 688
zcmWIL<>g{vU|^{F^C$T)BLl-@5C<7EF)%PVFfcF_%P=r7q%fo~<}gG-XvQeU6s8pB
z7KSLM6vki%O_rA+eSXOx9UvA6vobI+urM$%ID<@xU|?V<VaQ@k0qg5!s%0u+$YLsC
z&SFVn>1C{Cu3=ulTEn!Ek&&T<VF6nW!$QVdmK4?+mU#9WmIWL&j3AjBmMqQ%TniZ(
z8ERM-aMv(H#DW<#+5D>1RW%ZelS?vlQ#FefR5fx^%Tjal5_3~EbuB>r+|=UY#Pn3n
z;wnMaNL|%jUDXr?RSzpwUn|vMO{QDSX}Kl0*i#ZqQqyuvG#PKP6_*sHCg!?n^4wz2
z$xlzu%uBz;>6TxVn^;nkT6Bvy7%c6Pn3s~1T6BxKxUv|eB_+SK<Q7kHY6(ImXL@Rh
zPkwrOYSAsuf};F_#Pr0H)LR_IsU<!j7vAE?fS6Hyi#0hXHL>UxS7J&E)Lc#0TPz?i
z++qd$<`!of%mGn+#mPmfsd@2G-M834j!Z3D$xtN9z`*b;F(xK0GbgpUB)=$CuQ<OX
zKTj_)S+6uZrnES<s5mAkBgrt`$RH-SG9^E`G&eP`q&OzMD6t^Jq_`w8IXk~BwJ0qo
zzdR;2Hz_qGB{MI*I3~HYxFkO}9%7ANLFFys;?$DTf_M~hP=ptYfuf#~jgjd;8xtQR
z7b6!V8zUDJh%ORmU|`T>E&?el;$UE401JR!j6{I+-{P>z%}*)KNws4Hr5#Wr;$Y-p
F1^~Vu!!ZB=

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/custom_logger.cpython-39.pyc b/embeddings/__pycache__/custom_logger.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97fd1d8280827875e666c1a6c8b9995ee9958120
GIT binary patch
literal 664
zcmYe~<>g{vU|^{F^C$T)BLl-@5C<7EF)%PVFfcF_%P=r7q%fo~<}gG-XvQeU6s8pB
z7KSLM6vki%O_rA+eSXOx9UvA6vobI+urM$%ID<@xU|?V<VaQ@k0qg5!s%0u+$YLsC
z&SFVn>1C{Cu3=ulTEn!Ek&&T<VF6nW!$QVdmK4?+mU#9WmIWL&j3AjBmMqQ%TniZ(
z8ERM-aMv(H#DW<#+5D>1RW%ZelS?vlQ#FefR5fx^%Tjal5_3~EbuB>r+|=UY#Pn3n
z;wnMaNL|%jUDXr?RSzpwUn|vMO{QDSX}Kl0*i#ZqQqyuvG#PKP6_*sHCg!?n^4wz2
z$xlzu%uBz;>6TxVn^;nkT6Bvy7%c6Pn3s~1T6BxKxUv|eB_+SK<Q7kHY6(ImXL@Rh
zPkwrOYSAsuf};F_#Pr0H)LR_IsU<!j7vAE?fS6Hyi#0hXHL>UxS7J&E)Lc#0TPz?i
z++qd$<`!of%mGn+#mPmfsd@2G-M834j!Z3D$xtN9z`*d!Q$Hg=H&s6;Bgrt`$Ur~0
zG9^E`G&eP`q*y<_D6t^Jq_`w8IXk~BwJ0qozg#~xHz_qGB{MI*SU<V6xFkO}9%7tc
zLFFys;?$DTf_M~hP{bCCfufj^jgjd;8xtQR7b6!V8zUDJh%ORmU|`T>E&?el;$UE4
h01JR!hD3n$-{P>z%}*)KNws4Hr4dkK;b7!o1^~@txTpXC

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd97a912b458d313f1ee3b398718af41bcc1a946
GIT binary patch
literal 8383
zcmYe~<>g`k0givk`6>(yk3no?%*w#P;K0DZP~61Gz>va_!kEJl1)&+En81AIC}uFt
z62$_hS)*7}7*d#W*mBvU*mF6eIC43oI2j>&xT3hgYPh3#z%*|ZZwf;Sa}HlFe-uAh
z4NHzdu3(g4u27Uvu5grau1J&!Se!LSG*>K249sTB5ziHm5(l#xb0nf9QW#R$b0l-6
zqNErZQYEsaQ#exjQ#gB>z$&<MWO8MrWWj9i9JyTiD0xPPRH+o6RE8{tW~L}bcZL++
z6uuUQ6uwl!X67iRROJ-@Y^Di}MRlo4DFP{iDMBg2z06T6sVWOpQ$!XrMyaK$r--Ji
zWoe{{Nit+<HZw6Yq^hTggGD5eL?po?Qb;1wU=bN45m~T^9FmB9s#=y7+*E~BwJdE2
zFNHCfK~wQ1D0cibnQyW9WEPj)Vs^}{yv5;PP?DLSmzbl;cuUqdu_(K=ASkssH8&|I
zwb&;=Ik6<aC_J$!FEcM)ljRn>OJYf4acapeo|OFZyqx^Rl=#x3oLej);aglDsU@jJ
z`5?9?<1InI{FGGxBG;7k)Zmi*qQvx6O~zX+C50)unvA!Ei&INV3*wVYi%arz<8$)U
z(^HFXv1O*`<rjgJaU>^}q^9Q=RcbQc;tj|zE^*CGN=-=txq<PPcyLK#NoKNpQDQ-c
zbADcNNl|Gs$OpIBQy`8@2Kf>hgVG_0U%ZEbfgzP4iZO*DiYY}ol{tlZ4r4n*8e<Ae
z3Tq2*6iX^=3R?<$3qvzw6k9553P%bjn9ZKTmBQV^5XF(glfv7=5XG6om%`t|5XA+G
z;TDD{Zcr?@FhudB@}}}hGNg#iVN4NC5o=+I;!ovE5ud}DB9S85!V)C_<w>PTx3EMB
zrt+o8%wbHCO_6J1i4ua!$)_l^utW(%d5S4YEi6$Y?F=joQKG>Nnku&hLW@(2z<vjX
zb8%{kpPMG*Ev}TL_{5ZyqSWHzWOk5Op%}zwV_;xl2IWEv1_p*2wi<?bh8orshEm2N
zr4q&kOf{?v8EY8gnM+uj8EP2fSxeZO85XeDFxD{DFxRjwWC5l9V1|{9ewrM&1mcUc
zGYjGiit>|Fi;FY!(r>XB7v^M^q#EC1PtMQFOHD4h#T=iKRK(A~z;H`Ez97H2Bt8`u
zckxN3nK>z`MYmXsOAAtqZgIrNXXa&=#K&v0-(o3F%}Kk(Qk0)xa*MO1C^4@%Ex#!D
z7H>gOYCM9=1>&Y<=9Hus6$yi^;74=MO2%89@$tzyiN(e7@hcgAIqPTS=celCWF#4;
z8yV>5R;J`9m*%GCl@#lzgVK*laY<rwc79oEQCd!Zxjx*f#rn{+q*qW`B+bCU0K&!K
z(vFdhnS+ssk%Os95<M#PU>-_lgoZze0<l5Z859RB3=9nE48aUVA`A=+MJfyo3@aIn
zWMFKtK#?lQXM9D8<?(5mIjQk^iMgr8AhBX5ka0{^@)-7_s>QG!6#HP?L6+7qWHA&8
zfMPy{aUo+c*fp9=e!m3u^uV6hhsJJdiY9ZBD99;H`Ng-`b25udQZkE*)Udc!1g-_i
z#hM`Rg3=xX6B{EJW0fj~tFh_Aa63xKrJ#fyR=04#3MPG!{ooK|#1d05RVemxFfcHH
zFeto=4>2$>lrW?)rZ6=#^-I+<f<ml@aUo+ZQ#eBk0}BHSLo+iYL!N*kL$Q_t10xtl
zGB7fvGjuYfFxN7dFfU+9VOhxN!Vnu%%UHs?fGvf!hH)WNElUkcQC<hb0`?M)8kPl|
zDQpWFn;91}GBSYGu!GgG*02`2m2lLs!sKh&YS?O+vsutYnG4;(<rk==;;3PAVQ6No
zWv^kcVNPKaXGmiTX3*sFtCEMsY*AvlepzBpW=ecvX=*X3%q-R`$j-UNRFw4c|NsC0
zi!>M*7^*ljQ{pvZQnWQ+g2KLv3zQQe0w{q63R+NV2Ia$I32?HmVTcu|WvpTBU?^iO
zVyR)A$W+J@3`ynSl&#5hi%HKQils0<GX<1W;RygN3rY7}HlUQ2oS$1@=gPpq@EMfk
zt1K|mPI_udd}$usfcVUmVm+Ij{N%)(Vmm#A5t>}LSn>-}^KP*fWF}|lq~2ly)mcTL
zv{|GDPTNJPx7dnP6N{2FZm}g66r|>*++xnmE4js*UX)*2pveKwoVOU$qIh6o2D0lG
zdv0Pyd}c~Tln|0|Jj9?VuFRAQ7@H$6wW0*1@)kFedm$b+1mz`A&b}p##Xm-%1S|*g
zKNAn51``V-3nL#R4<ies2os3pW8`2IU>0JmQXwXAP@)c04Dc{8Fo4ShAtq2dU%&_|
z4Hy|pm=-WCWT<7VVOYStkRgVtmZ_E*D#lX6uz)3nsfHnoIg7Q1rG`<0A&p6l0i1FX
zy4Y%%7BbedV(4MR(8FHCx`1OLLkVXMYcpdCb1#z%!vx0Irdl?rUd9@REUp^1Ebes1
z6qXvcES_}66xLqm80K2`T8>)I65a)THS8&Dk_<H*ATkYPFBeoV)U7G(INix#!<Hqm
zkg=8<RYwg&7C*X6Yq&w~bYX~n62nr<S<6$)Tgz9%n<ZGom%^6KGJ&zkC50o8DTOVa
zv4l5E2-J40<p;YR<YJZ-F1Wh|Y6YS0767@pMi4}T-7N%l1t{EwYj_tj)(S&KpgyT#
z$P&hIm+%6S6z&?Xg-nbLH6S)mjS!s8Tf+}xm+&qStr0E}%MxE8QNxfWnZj4YJewf}
z?EBdab6FNLfmGEnr*N1v)G|+CEL1D(NZ|@*&=mN^$)x}VRYC#q8eaia9O&ukRf)Kj
z=7HN?wyMQC3L%-fskW-cnvAztOY%z+b3!zkqBu&6Q{$6Ti%O!nOEODxQsa{nOOi8+
zz@>^Js8y!`uWuEgO#=lDSa#AZ0u|sO^HNGG3sU`VA@W~pZqhBT#JuA2)FM#h;udR8
zVp3|(Emn}r;-h%sg<E`PO1!4jE#{p3^jj>Md1?8#SV}T;Q;R@dl3T1PsY#{jpf(7o
zevOa6C7qU9lAHn8TpXWN32DD%rW6-(gUSl7%#zfilG1{l)Z$yrIjMO?k|0rDn2Y0+
z@>42r@xi%ZKZAk;T(%_V=OrhW+>(OW0&#tDd|G}H+{Rn%;O0_h$}P5%)V$*SB2ZzK
zmwt;s8P>>+2e(;namU9Or<Np^loZ9s-(o3FEK3C!Xt!7j5=$~}v1H|E=7Abc;5J%n
zN_<LYk){Zw7-K3iEV2hVnmIG2;ucehQ4~K~eQ=ApD8C%3P%F<bN&y9OF_Ktveo87>
z>=t8ukqZL@LzFPgKPCBa??iD!1i`@)#Ruh-7N-`)XQtd@El$oaN-c5&=@x*>A?$*Q
zz-=l5)diq3Ik6xEQg(uxPDS9dlO0+Mc!0`FA5eM8%*DjP%)`vVD8<ad%mt>Um^heO
z7}*$^7}@^vu!=EpFp9BoF!3;f)QB)DFtRXm{pVo@iE)9&co;!Yjj>7sqjZH@hf<<~
z(ln?-1ed6M;Fd!TLl$E*V+x}LLokCTliy2_v?f!LKd8*t1`$|VQefGMAb)^VGcZ<(
zVA!6MnwJcg1bLK!fdNzngIf9EHcJo#0|U4{z?j0+!coGofH8#`)B<T{tYs=;s$pto
zEV2W&OG+5BSQoI>Fl4dUFr~0aGAsm@oy;}N!3-rFDXbj~vl&v@<}x=kHZg))T$=2D
zD4qpb4{F$fYkmb-%`X9}`4=#xFo3krWm<^T0?=eEa%W&*0QFOfOA?Dpipw)gGBlZr
z3_wYd1JuX`<-wJVMe(3;fyZo-F#`j`RFIGfs9mCg5f<sGCD<zcDj`^d2O3qWDf*C7
zM*!4kQUKYmkeQc~TA@(It&o>rqL7wfnwO%d$qCL!;D!M>CW=7eS!51!K{SYn0THnv
zM{~MXB&UM=&6+IWtOa)ZEl!YMzzz0WETGZ|l2kK6+Q3PbIX$&x3MhIZJqZa$4n`>^
z7A7pwoSs^u$#{zcS|SvqL^&v`K~WA)Y@nb4M>Jy!(*mXx<`k9|h8jjtLvSJEY=#up
zxy+yndm&RTQ#wNm+aksqrg)YbP{WYjuPOyz!b1ysg<w}7SLYA~P~q&EqNCuHpHc}b
z+e>m%brk%|^HPgIap0Mv;1=ZX3s!|;Lb!)(kSkcqGeyDEPeJ1sqYg&u0VUdCP&j~E
zBIyh@46!`5j3o>;jLpbxMoq?B%*7=|u!O=`0SQjj^sxjKxS&9&asasvxo4~gDQjT`
zGD?Y#HQjJ&!o8!(QWOd@IRHcif=elwIq}dkOq2B%Q$fltzM|B`l=$Mp98k$uR9OU0
zB%t;%s7stvnwwXw$qY#%d7zX7uf~ewQ_|2q<jcUoun-h3pdy5Uk&Q{{9~(2vUltZF
z#ws0Byr;={iyi7r+(}0dTuy;nel?(^!UQIn=d#u?WHEsfObt^OGiYQag$*<s21_)(
z@GwG4G+G!Egi=XJVtTQHkEge*g1Tx!Wl2VUo~pWnqo0d{1}FiBR2HOqrYP7d7-~WU
zHGC3_OI%YjOY&jK&{m;JL{q^(NI?TRaoQ^U;?_h&<t;XlX`U$|D;aOGXXX{;Bqpb7
zvJ@48VzwAWfFp~uIJG1mTqwYzs|+Mp3L?NM4orY^u{f4=1q$qypqLW^mCKB5%uIh-
zn3?{vv9SDMVdZ12!YsUS7=k;xKvf$gr-LI3dvr0RFtjq&FlDhUU|q<-&ydC>$*_QJ
zAww-lj2$yF|6;@$v;=U}14WLGf|DMeXjO2;&<;uyo_>C=K?+{}o_^qRK*34D-%mlq
zNiQI=C^fGH9HowWkT`+HERvjOu!5g|h=O0Jj}IaWL8gK*q;P?_4ST|I0vilZIz~_n
zG@SIXB_1(I;&IYLO+H2CpjfY9U|<L-ssvSB9O$WqIVV54s036jaHW(M<Ya=zD^qT<
zB<B}Y7FC0EK(jn34Z+%~STg)AQAjF)G=RVb6euBW0wp9+(a6Ba#=`W6g@x%48!O9i
z7B)FXKE^6F90dWIA)1W01mHy$v>w4oT%eIra5bWYRu+J(4QMF<>Is4pP%wigvtLyX
zMw~;F45YF^3P~M>U{E6uLp`Bl0V$mnIf4QcJ+nr!r4;2C<R!v#Y9lDR>OllJlY$9Q
zPAv*yU|`q_3QTBu3eS-oj8)nMf)TlCjuLL5$_pNDVR*tViy2fMr?7%bPjJR%17}=z
zc-S)6Fr%qq0gbw*Gl9nCG&%gL_G5${G^`<6fa2`woSzFC3`OM7{9?5HiQxcZvLGmA
z!P&e@5t{!X`5m4u4G~#8iWBMx&lE^zOXOr^U_h}J&CkwgMG`zYVJVJiR}>+c4E2D5
zGpwA0g`zXKF!D@6_?Ear%vldJ1ThLWco1o_qLohI5~v8&zJfJQLB&v%L^33{i{tb2
zKy4Ms0APGdT9g>F46GQzXr0^=gA`V9#YvUmrqD4^29ttRd+?HqkBN<y={Fl2%P$so
zKBg*lkcY6e&Ctz4$!MTH20WvIns_A)&7hi$VF4p}AeOm?VF44E#R3|x0<%~_bsuQr
zqJ}kvv6mU#_0nXH;)Ioi#Zl}~58mQ{S+2=iBmgQrYCuE-h-d<})Hvb&ZcR2w0}dRF
zg5XYgJX|ur=oV99@h#@Wyy9C-$;IeRx(JX_pb&yI`9VYBT%h4_1&k1Zn+WOVYqI+J
z`T4nNiWPw>{##t}@wxdar8%kb@wa&5;|og@b3mP@`1qm%kc&a{6Sp`Dit-Cmi%Kes
z)Inn4UZxI+1sd@vN(HgN-OpSQ3)G7%vIen0t<EBZ=Rt*L5hw(Tz;5S-&Zi`&=H$f3
zBY6*0Wq^C&MW8khC;=3MI)faHB3w+2T+B?2T&x_-9AX@v9LyZNLL7XYeB6ATd?H}g
zpmCSn#LT?-_$m%qXRS(DPfs6ZK1MGk$xV|9JkwL8$iTpmX~x39P{kJr9lTb^OUcZ&
zy2SyD+bSMM@SqS>oDCAuRlKnL2~o}tjrJ<eki>MTQs$Dx^eP@-aQH%n*dU=@#R;>`
zs)`+Ayp^WTEq2h9C3yZTzgQ16>;maInt_57G|g5d28wpJlFEY2y!0Y)P=k7HMWDK)
z$O|l2keHW}SPaR!pnhf)b5gNM5h%Ui5-G_qO3sK+PtDIwEh)-O)&ut?LD``Q)bA?-
z%~2F_f?6L6*i`CefahW3!MZ@n4l;bjQ4X2LExyHzGVfXh8VbI}2borbsfiK;HyPjq
zbg;QQ@XTJ6WHERO7LrgASs^n&51f;3v3UBq`4@o(WWccvNx@7h#YNzhmjm)YYe5lc
zn4PU6J|#1`<Q8jjMq)uKcr*z-)`OJJ;MMsp4jV|M*nx&yifur%HH;ikD8k6W1Y+|r
iN`QxMK!ZDcOgxM{Ogu~i^$ea|oP3O2j6!h8#|Qw!({W}1

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7b5680a08ba3da9769d0355f247721d87c640624
GIT binary patch
literal 3325
zcmYe~<>g{vU|^_O%#xzW#lY|w#6iaF3=9ko3=9m#T8s<~DGVu$ISf${nlXwI%x8*X
zN?}N0%3;oBiDF@d$gxJTrZA*1=dk6nN3ny&SaLXWIiolk8B#e@SW_9YxSE-wxZN32
z*izVA7*g0%*_)Z8cv5*&II@{0Fczt$@-;I=@uxB_5J=%%$mqfls}Uub%CtZzg)3EX
zAybras&G+ND*pnJRMAx71!5`O3mKak7cw$36jr5*r}8&5rZA?8rb>WWJR%IujIB(P
z3{jFPyeWJw45|Ds49$#DQmNbvq*M4)B~m4u8D}%hWdi93iwUGkq)H))@uUc*N~Q?G
zRHg{eWo~AUl1T-LOQndwR71sOQy7C8G(}&6LeWo?@fM3uW^suo<1K#AJm1vZ{Gv*i
z#FE6~RFIHxacW6vL40y)aY=q|d`^CPdTLRXU}m0fE?A*%3Pfcx$S7pY3gr}MF)%Qs
zGJql?iYY}fl{u9ql~s}<g=r2`J3|^{3Udle3vU!#DtiiR3R?@v(;TVnDeNg6U^Zte
zdkSX?7nseJ%AUfV!UJY=LxL%aCxt&npoJldx1E87A&M`UK~w0KL}+npk$X{MK?c+p
zFi-huGTq`$%gjlQ&r8frjn6MFxy4ghnp#|vnV%P*nNpk#N=Q&=Gl1A^3=9m+poHZE
zN`6c=4Dk##%ry-0j3rFX3=5cR7;9K+SZml8vM@4~uq<G$VaQ@jVU%Q8$QT1s%U;6}
z&r!pW#lC<ug(-!(g<&CM4MRM83PUi1CW~JYBLf3NGKhfL2x2oaFfaszY~%ph$gqH6
zAww|3N=8r|v8AMzBqnEQGTvgWSjkYt&A`C$%U3@mKQ~oBCnL!)-N-;cw=yL^ximL5
zucTN%9TWg2#U+W!+4*IuMQJ(t<@%|)NvSC*nR)5O`p{^L&&-R5M5JCp<t+}IoW$Iu
zltepW1_p-DAa_+6;|ff@`1s7c%#!$cJ)4~T<iwm}J3W|gO|Dyf@u=a<8;=~&w^)ly
z3sQ@2alp*D#afbIl$>#kB_}^I<raTIQGRl2adB!&d_iJKM)55ckZW(T73X9orxxF0
zO)MzLsnle@#ZsJ_lXi=xC_lfX2ozOCpg6n5SyGgkSDcn#lzWS}peQvS!Q}#R(=u~P
zQj1nH-r|gpPtHj!E{=}}hh-5D0|P^m5Ca1PA1M66B@QDSBO9X-qXc7>BEew5h)Na)
z1_o!4Q34DM3?&R%j0>1*7_yjC7#A`HGiWmTMF}Ik36F)9EVr2Qi*K=j{C$fhD?c-@
zNQ!}hVI^acB*=N7NJqF9<Ud$IrDPU~f-D0iAO@x?HJt86SelYqM6#oAal-W%DS%w0
z0wU0y#a@(H9-op4c5*S)K?*n>1XYaNDIyFE3@MB$Of3v03=0@jm_cQ2GpPJx@k7xB
zvI!LE;QS-Oz`&5sP{R-_P|H}un8HxTSj1ApIFYH4C73~zQIn|(oHihdL9d{sN&u8L
z;7LWVpadMuE17OF=@}G(5(AhIAwVhR7MD#<W^qY!er|!CJSYr6siw*ruP@<Qq8K^P
zz)blisHX?cmiq7@)?_MDWME*p#SXG6C9|jqTWIrv0uW)eEK({^C(=16!3ZjbKp32M
zA;nM@LpozHLlKCh$yme>@-8-)h`=3=$o<7255N+EE>8Dg(}~+vpjZYKqhJ?7;sl#}
zI8yS<^FSqvDm+Z^<z$#@6njDC3aD5ChsF|cqN-)AVaQ@|VTkprWh!A@z?8yR!?cj8
zmbr$x$fSh1hIs)?3KOWVuVKhyEn_S)DdZ|)Tfknzk;S=ytA-(q8>D(6V=YS!;{u)<
z7Ep~|$W_CzfVYMbRB_a@)UuZF)i5?Q&Spqqu3^aHPhpW{n9VR3#1=?lv0*4-TObIk
zc+!~QW(t7REM#2B1Zu|wGib8<{o=?>iPwlp(bjwkip*b(y5I^&Q>aK6l%kl5Qg5*p
zrzV05UFOWZl3T3lMfs%#MWC{(2vqtPfpQ%<^%W_B;#(O+sDre!B^DH<=B3<XPtMQD
zNi0dd#hjCxSEL7$2bDm#Sc((NQZ?DYx&Ia?sN{|ZCEr_&X}5%sDuH;2?L}fBjp85z
zy%d3_2@O!>fmAUtu`qEl@-VV6iZC)Ui81mp@iEFVR^cvLpay8N`1$$yY4U@r8*s%R
ze~T+VJ~uz5G$%Da{uWPsd|_!~4pfG@xTL5ERFV{#g4_hE4{mW36y+DB7L`;Mff7;?
zC=|e<2=TN&NSz%hWO$)9N^)vW4y;xIdrA(Z1XOAifm%hpFb{DsO0hFBaxo)g4rZ_#
zkXv&TGxOr(i?l(S4M2n;DEXy<!83c5R0?u1C{ut_DQeMIf>QK>Gm9o8xRT5SH+qUd
z;R_aq5MXC&F)%Q=f~*9UYDFMkl|)EUVrE`uUb+G-5*3p3OY=$;s>DM;O(-l99Iyhd
zime2sS+|N!x3Cyg+On3U7MJL1%H3iIH7GOl(jh6$7!-<ZC6xsr5s;t2N%58lxN#7l
zo|>PVT2hpmtOrU|MWBY+Enbv%X%Q$4Zn1d!x%uB>&dE;)2W62x$W-Ry)Dm#}phySg
zFB=eH0wQd|vWY1vkO(&eaalkKw#X7B?Z&{s5XAv<E~tfgizPEJE&moT*gNqcXB4NF
z+~NUCLwS%;1V=lV0EOo*4jV{vvjgSDVo->4FoGZt69*#?69@|EaWS$n@qv@K7_$r`
MACm~92$KXO0PwpxBme*a

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc b/embeddings/__pycache__/post_embedding_builder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c28eb16ed26e7813cbf1dd1fd1857bcb9fed14ba
GIT binary patch
literal 10916
zcmYe~<>g{vU|?Y6_?OJ0#lY|w#6iaF3=9ko3=9m#S&R$}DGVu$ISf${nlXwI%x8*X
z0@KV<%wU=&iUmxwMzN+aq%h^M<+4YyGlKLm=Wyh5Msb4KEIC}cTv1%P+)><MK5Gt7
zE^ibsn9Z2O7sUrQgFlKNObg@+MhSxD*m8t&g`<RXMWRG<MWaM>#iGP=#iPV?C88v9
zC8H$4YS?q6a;2lB!EBBknOxZ@Sw@Cb=@iaXrYyN;rYQMT_7tvkh7|5ajLl3@3aRWV
zJSn^>eCbThOi_yN3@Q960xb+F0;z&oO3lns%Bd<Tg4s+H7>m+Um75u&R8v_Nuq<SV
zQcG1&5o%_1VTd(~(nwWZpqV0^BC<d$MKqmpA!9nD2*X0gEbS<rRMiE#DZ(jYa1p%;
zjD=!QGxSppQp8gvQuR~SBpA||#2A_x85vR-gBdg>UxH%6Pm}Q$S6*UnYD!6IK~AbB
z<1H4S%;FMFrdtAjsTCz5iRrGnNvSC*nR)5A#L-1UiV`#PQj1nH-r{viO)M?ROe@U^
z&Mz&{WCsZpm4xJHr{-l=r54@dbV@BM@y$<3&ABBIl%JGZRFa60<bjGn<ar%SOY#wt
zoFFblh6}`W&d*EBOxI+*C0v|ZQd$t7Tv}X`pBtZ(pPrstq{(=TGc+$VA~`=Lwdj^u
zL4I*bd@9WS@kynbIUsR1&)kCiq7qH!TWq=cDWy57w^%?SaEmoDCo{457KdADUUErh
ze%>wKG$<<`EYDw%Sd^HXT9R4><`*Y}0uULqK{>??j0_B^44~v2#gxJr#hjwj&XC5K
z!kog=!WqSq%9_e1$&ktpr8%-VQ&{J)q_Cy1x3ENUrLsZ9Q#j@@rEsQjwXj5Sr>Lg#
zq_V?hc;>Kxl2Z$76mN=JDo-kFDxV}nDr+jgBtsfg3jZ9Y6oC}M7S<>Mh)Jm+8-(UC
zfs#)POOzmzsOTJ~6tNWX7M3U>BvFYuOevBnQY|b|!YR@zGA#^IBB^{SvU3<y<Wl5Y
zSfWIsJcSg+7M3WnRDQ6WQi^g5OO$v!0}DfxL@<M<`Yqmo{NfUL`l$-wQczG(2+1!>
z&QO2^yh2)jkwQsPVqS4teo<~_Ub;eXNn&z#epzZ!T26ks0w@BD6*BWm@)Z&lO28UR
zQuB)Qi}b)oxg|3}LmWhb*enbT49uXyq=bQip@gBCVFBYp21bSwrUlF?j0+hV8EP17
z7~)w<SQoIRFf9bJ*=rcG*t0mYIBOW<xk|Vf@T4$<#d%BkvUor-mcrW0RKpO@Um}pg
z)(p}iSR%B5CxsoXO1MM>&K6B!2xida@T-zW@wP%{UU5lcUP)$RNotB>6klpyd~$wK
zYJ7QWQhaf4l^`fH=q4o=r|OpGB_|fAri5rRMv3H?mK2nh#Ajrtq@?D>7ndZKq!w36
z<Yp!p<rn9tmFSn|WmbSwIaq6zI6?*_3O7?z>=tWrX+dhyEsps3%)HE!`1o7QIr-_g
zSTgg{@^5j(Cnc67XQZawVl6I6Os>4el9Qj9a*H`WH7|<8B{eOvG^eCEimSLJzaYLm
zzbK_RiYp$(0kexZ7#J9C@uwB#=f)QlrIvsTgOpqR@gUpdOJI4wh!doY8zPzuDgufG
zKq6f6a6c6BfCPl&VNOTV%nOp?N63H;&}6#BQk<HT7R3&A{z}GMobmC=If=!^@$oAe
zeg)}g<maa9=VT-qrW+aP=T@fVCzs}?=9Ludrxzs_WSA5qCo_F`5mBrUN^zJ)hF(Es
zku(DXgCZztiGUItBNrnVBO4PBBM1sGaxwBSR*9j+yB;jSHCb+PWt8ORK*IVKM`>{?
zI52K;mt>aYq(X(+ic(8Ti}I2|DHvn{2!p~26nfwk-2|#08B>^A7-|?@7@8StSxOjc
zSehA27;2cC8MB#+mAn|57~2`z8Pk|jm{XWqI7%36SZbJ?nZT-9!Fd3rhq0M)HbV;Q
zTxL*yC;?YBwQL=XMQb`3vRI3%Iv5tPrLckRVFl@2$XLr>xFwyTmc546g&|h0mIJJU
zeF1w72dEla$XLr+!@hu{gmVE`4d+5|E-P%x<0#=y;b>-TVuaej1-73vg{y|Eh7D2`
zf$RsDi`*dB*06wVtL0AN0o%<D=J7yvvw_?;fw4%UgJA)vDg}iCSp7oATHZXTbcR}9
zkp5VUT0XD}z6I<l{9tqV7VxGBECh!b)V2IIyinKjLtHz7u}}eO56qQ<sIIJGui>rX
zNn;LX&=jir3@&>>MVSM*LQnuju9bpEh_8=AaY+%V6w6OjD9K1wNXk#CRLD<LNQ4w?
z{>a4|qy$4U1)LbH6hc600-~v~G_@F14}pt84E1R_i6teec~%O(sYU6j3W+&63Lxiz
z(|55#Nq%~2Nk(cB!mW_>Wu*X0qzdp9uTYd)P?TDnnpcvVqL5gOR?Nb6Co_W5BNT%Q
zT2=-IhG0-`5My9qNN1>Fh~=vVB}K**hGwRTOoc4L44RC$SaK?IbBaLKq9)TVCOv~&
zETFI|(q>>_fDk&M?9FA9lUZDnoS$1@XU4$5@EK%Hl_toY@VY}UEx)Kdu_#5)CMQ2R
zF{jv051|LeZJ?|IYEcS<+@=6?n?wv#En_WH2g3q}8paMr9)=pGh0G9Fg59CXc#AJR
zwIm)9fSSxjMj*Gb6z7)~++qbs&Pt{t4Uo(AKm^ziAm<mEFfcG!f`m*N7#OO!1B&vK
zQ;Um1$)Ji29CgJgF2osnDPY$vWQ4d&lkpd$0w@Tx@-y=^nINuVDkuiK7ec^ocZAtq
zWCe19i+fIfSSlhlAjNiGYHEr?Nxni-szOm}T4HiZei1~VJijPgp**uBLjh8x7XK1~
zHG^zbi**!2GILXHRg0^H0}_i8!D$R(Du(w!5e|+3DOdyu)iQQ4WHHo$8}z-f1i}=e
z$#_c)6!4kg7FlsT$OZ95sp+W|nvAzNKr)~@VI{=-Mdl0)3@aHSk$@Whey|Y0a0{qB
z1t(Ncu?<eBJhhA<cSth80=fv4xWPdRwdfXO1(*vV;KA&V(-Kfk2)0BG=3%*7Mo@gF
zKnv^zjG$%(s7QmyFUu`PE@EA~fMFrjwO}W$goaH?YDGzr6UbHI&;S!~*M@+^!Aa7D
z=&Xd4t8TG?LSB=p$QqOlIMOonQsNVHa*9A*HLxQf1g3RGE({C|WO}T~7364KK8gTI
zf_zlu4oc=i&iN^+POd>Ah?J<wcZ)eEHSZR4Rc1kv49MZEDXB@N>9<%)GILXl<Uk^V
zCHe8-77Dx!gm{rHv7jI|FXa|1xSe^6y|}P6HMJ`Bmat!biDzy>PAaJVkecFJRFq$I
zOSB{(-h4qc7m7R>7#K7KAPH0#<XUi1q0hj;5XA~ANQ$Eb;YAFj-cK#Q#Z+P##R(Dy
z8B|<k15(Wa<KJQ`F^b{^MMZpCYGO%gQEKrmc1Yv7IEoW2o|>Cf49*%+956m4Zoq{S
zD1vUW!-`K(sR$}#894ZuL>Q%*#F#l4Ss2+E#h6$aIT(c)MVL95I2gs4gc$jlIT%4K
zkhl^^oKcLaN)&5K0yW3LNznil^Put#RDy#rDD#1^Gss?0X<P#?<3Nq_g^aaK(A)uO
zpfJ~fs#oT!x$sH{TD^jdwo*WpJ0QLSsMU#9tAc7<gx(@>bIM8q)SCrWr{MNzu|j!9
zW^#r?O1?r~eu+Y6UUE)pN~%I}NqzyiNvoq!0P69TCW5Qo{33<C(%hufqGCkli;=KF
z6%IJ-tH83p45)m`Vkl;*Va#F#lT2U|rNUvk#hzIVZc9gTWERJRt&QTy$;?iT&nwL>
z3I`=5NU3*=1>`hHjH2f7bWjQd<%cRA)OaZXRqEhc2wQ!r$y5Yt$KGPi$uCbW(qw@o
zSCEajm{anTZ?Qpw3eCn_TnJMWLD2~+<QZ7`82K2hByd^pr^y0NnMJW6r&@pra2f%n
z2~hd~VNgo~gh2sWoW;PvP{ULt2W}RE%1(|N#sv(pva^;YPb7r}+^j<ptzk}KEoCf<
zsbMM#s9{;a*uk)XsfHQSxXa@KiPSKI#!MI)pbcAa(=nJqlfCLLBE^8RT&@Brn?f5&
zpacbJ)|D553WpSh%shoe1(cx<h5V$f)Z`LG0)Zu3P&3X-!81)Eu`Dq&Cow4}RYxH)
zB}D<;l7O_dKo)|lQY76tk}as8oSK-NfvD6GdaOX^<SHPd#|rKzjMNH>G%PhUs96Nb
zl#n_R9QT?`kZ5xQMH{$G2DdX%(_j%u5}XDNP!j>D?glvw?7?De>F^d~0fzfP2?Ly)
zkeijDB(RVn9ojGgyA6Lu0qXo0fgI(BVgV?H!7V^;nt(b~u^hEbHB1wk3wc0oSBNDb
zufiLGAkTm*hwRimglrM0cm=lrYhVGP$y@}g28zH<5>UF=WQF7~Q0oN}qM(cr#R@VI
z+MELiEZ8Ve;S9<I#h|_)13L?&6q5)OsEih3;$f<i!57AwOpv^Wk;6c(6A%WM!d(mu
z3@M=8!<51d>PjwPOksg_k7`+}_(0vX8kU8OwXER`c~UG4B`lRrj0_-{&XCTK#>B=D
z$q>w70LrBiC9E|J3)pHH7cw?8*0SaCf%)t;tPp+;b5ThNM+z&b69|>#gvznku+^}q
zF$FVdvQ;&~Q!6y<P@Dr{dB{0GHz_kOH3g|x2C5(xit{oH3Q|jo;TF@b$<ho<6;)c8
zQ3rAiwu}L3-q;7Frn^=sWWt&rkU9xe5Wq$y6+oks#gLi|rYbc>p)$2ZK~F(XlMCFg
z0Vl>=%!$P%w?HYbIQ15HaAjUeVg;lj!vgAh-4X@mR#+RnEHe?(1ivMWE}U3gQWOYk
zM2Uc8VD5+q=}0Ut(d0l%=ysrHj0CbW<ktNy9+2nbp)qlbyBO@G_>!X3)S^(3PQ=h_
zN@^N3mw>V(B*WP=Ffi1CDj^AQ4I;(J#K_0U!^p?T!YIJR$Hc<O#l(j(9s_a<$W~3J
zq8L!Lfb$6x$N(q?l@Fkr2%Jwq1zrh54P!H73S%}?kwOXM0;U><Ean9)DNLZcWFccP
z1B4Gs|Fujdj9IKTpcV~tFB3$cG-x=GrK$v;%%JHTQoSNtCWwL?+!f1C%|oP22UuE$
zjG96QKS9lz#In?)#Pn1JP}=~}{)Jlw?$X?1Sjl*cIVCe!ldZ@T<YRCLCIA$?{PF2I
z`ALa6@nxx?K45W_5GZuYQj?40lPV#@Z^gHmi%WAgnIUls3b81DkZqt=Dr7<cJz8$b
zl;nfVElWjoq!FqogW?NR3^B0sFp7Z2%h0j}KFyj;w^%{#zTzT`SOZmQAY;L?Rsd>0
zKsuU?pw=a50JoPF6gvx;7BciZFoNSOjUkhvhG_|7A7d?Z4RaPtElUk(Mv0-7DFrm6
z#L&u=#w^KD%UZ)4&cMWw!ob47!qCjj$Pmn6$WWYQ04@y~86p`#^FN?6%rE2?Q$fi~
z1||juO_p2i1&QGKk|?f{BG8a7s4W@A2Z_jdNDLO=;>t~|h|fvQOE1X)$CV~)kuofb
zZn5X($EO!1rWAqNlu?2?iN&ZR!J2H4NXZ9fMb6Z`<b2TlOz|!DeDJ7oaTFI=OHN{C
zYSAs`<Q(*fD)MDuV3-Sv5759X1EUBd7oz~E7r<C0k6KJ&^p!N3pjj6>On@05pwSml
zE&}DI;tU1`h7QIohAdE*n=yqcg}IlxmZ^j(g(Zcxg`o!A@^5DBWN3#BTe79Fws1g1
zYne)z7qHYYEo7`^?qFQN0BX<0Fx9fuvX(Gsv1YMlvDdJsu%~dOaP~6SvemE!GiY-C
za)D<6_tX-F<osMvpAC_lpe-d(A%JKjSwUh3Ay8Dvz`#(&X9Et`)FR#dyqroqO{St^
zP-wG&2B1+|6`&FX-l{+zWPp?c$PFq+NHzo&E!dh6Yd{GA)NH9T$A|^kSUDm;VXHYX
z90rOEaC;9laFq@krc(g*P9RZ|#gxw2!BD~sD&!;>L>R=8#~3tOic~=!W=jT7xJB_4
zB&I+LhWN}Bgg;g?MG5BQ=cT6>m%wF<A*mBlT$?a3Fsy}n?j<O|UV?^PG?|LxK%zOI
z&JG8}Iu3B@m|2y2i#a#3q6k!C-C~1enp;eHdAGP=l}=*LEzZ3BqTIxs%&JsPc8HY~
zpt6k<Y(#t}Xtts_zqBYh70ea{3l}2-usDhzloF5ygyANpfTtWla*&t;H}AkLv>*ls
zhNYmq!w0H1*n~j6W=1YXK1LD7DrF+#4Q0*>l>gv`R}CX*&T0W8q(Vt&sAW!PsAT~+
z$e6OgO|mMZ64n}4P-ClxEsL#$DT_UusmQ5@WdTPGTNWp1JUx%6h8e*NXDF0oVJKnB
z;;PI5HQ_*oR-s%ZM3lR-1}-)QD#lYe0WJnsXTVS>SHqCSyMQT$aRDEwVF;QJKoVgF
zb+c+YO89Hon;A>^YdD%2#TinVvRR6LAd6+Q6uqip2NlPyOp*+>?1ejOIN&^v!W|_7
z3j|YGAv$I=q_9Cun$0km1>`G`eOW^3j5X}ypaPsXOE{fToS}vT$`b+cYB@ow7Kno8
zAZod47#4_?h%b<+;aUiqcV$jtZe>bil4MxOn8Gf}P|IDzoh1n_JQ%YiQ`l2DdqGp;
zTt$_j^avV6&HT&Az~J}t|NsC0tGNAq6u=YG#a0Scd<vjpCpaHmco%`%B1PGtWGr=y
zDKq~TM=*FEz|+5|o&hvRmY$lIT9jCl3dthHw>XMY6H{_B^FTAtsTC!+SV}4jQg5+4
z`}_L3`h|eI4!77Ka|yS&z*a+AYrMfF`30UOAhSU8(*mA(DXDoS;Fcj+K@n*D5z_1f
zSNNbIup)18rbx~yhS(m(S)5t`W2WZAL`9SH^U6|-pt%?%3d%J_>YxAt<-(#ykWW~0
zQxo$vc_6tBGzxc%IjJ=578@j=-Qs{1xlzzen4X$flmNCrF9%|0aTGT=*Fkb4xFr_F
z1>t}!yv3ZFn*_;n@bcr9SV=x4JED%;Z3g8+P}#!3&c(>W$il?K%*4n6Y6mh2F!F&}
zARZ5s5R(wI2%`Wa9}^#w9FrO&50e68l?p~)1-lE?3Qb1L3Jz4cfy!0T^nWp^HHEV~
z#gxKY%Ur@#!d$|V#ahE$!&JhO#g@X>%iPQeUe+O0%TmG)X?ip>hcPfR6k34iklFoc
z*CTlbicoBo45$H!(FMR($w0a%T5&POS{n9Npn<p=h(67jVr`9>a%~NJtC(UfotR>+
zm~w4RElo|vTdZk0iRs0f9FVjL?jNUs(i3xWYRN68qSRaLpaJH@oSa*nzM(!Lo<5#_
zu0=wi429fz2DK)l*aI9xLR^FVia_-mT1qKOWnf@<3Q88BLY0A0iiw3u4x?Dc>Qj_r
z86}F-z)J}r#j*=SY#6BhQo~fjkj320RHOhZrt@S<n7|?|%}hmHHB4Ec<pqqOmP;N}
zA&AfJ!VoJ{%TmHz!(79X#j$`5v>2g=C5sap?M{VUDWKjk^DlFF3kKR-04GXlKMB!B
zf%OaEs&6ss`Dt<$fkx(T31otns3d1(=A^{u<)@?;-{Q(FhO9zJP6dbREiT9s7`LMQ
zTyQvo=KpgtOK!1(yOc$sX3Q<N#9|N+Hhhr@iZ2$Bc@Q0p7{ef-fcgjueDK%>E2u}t
z2F{I4Ok9jrim3SwJrWX&OHiT!)T9Doa4rK)I+QTLqrd<q3b;TGeNbaEizN#@uvg1m
z!dk<W#Ri)D2InfS6wuHzOVJ$A_%x{H0jjw{7}OL2wNt@EPN3ulZX$7knn+nFjUmQc
z%!x(mIOke87#SGA15Q<Dn7)LySYX{`Y<bF0lN}OwEaizg*+ug~Va*MjfOAPryTy`N
zlwMo}8WjUa4s6sG6e&eZK&rVw!%5)9Nf@nt&^R$8BLhP*I8%V*MToIViLgsG8QnA)
z{WO`}+%zRY3rIl8AwK>VSA0BV1|mNG7EgS9VQFFxRE9l1J|#anJ|4VSq6ieSx0s7d
zif*wa<>%*s8*|`lq$mtj2!Y1Nia-Tv5opj0oGu|A1oiHV@<2*J{gfh5gR5v7hzrWy
zMW9qx1oF}?Ug(mK<kXy;_;{qEKog`G+*p%jU|;}Mg2kZWJ`P3}7Cv4%V-Zd%5jjyY
zE-?-<J`N*39zKC0kUEeLa}zW3;^QIpc`e8k&}3GXK4{b>6Eyl;l$e*Es$p!T3CW7e
zkiLV0nL>V%LUDd>YDq?BUOHEmL`X(vu|g(zIz6?xL;*a%o|l}eS7Zx{D^B>-lcpg!
z+Kbvi!t9_$Y?*oKMKeJ*`+x{=^MN}#KPM*@yd<r-2$W=Rv6WO7fE0p4xCoR>ia;gS
zE#{<R6G({4gEWJx?OPn+=`6jxyrL#hJ<J1NZIxP73|_$lUY}D0TBBLS4blZ_k`{rM
z-4uZuxVLyw7T6W>gACz>^+t*yOL<s4{oMS)^KFP&yv61VURwta+M?AU|1lM$7J+iw
zEmm+a-eSyz#4UJ+k+nD@u^{ypV<Nb7dW$is2sD9ti?yIAGp__GZV{~l4jV`+v;*bn
yV$hf>XvPYLc^G+^1ekajL3|!Y70^l@CLU%!Mm}btdWJ&K+8t2aK#CDG0|o#bD9<SX

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c148ef5bc25a569593d938dd5ee030df7d33280
GIT binary patch
literal 7007
zcmYe~<>g{vU|?Y6_?Ns~fq~&Mh=Yt-7#J8F7#J9e_b@Opq%fo~<}l<kMlt3xMKR?v
zM=>*k#F%neqF7QGQkZjCbJ?QUz+x;p?719K9AGwU4reY`6c-~yD$4@yg$!VtXCcD^
z-c*(>zJ-hsek%V$#wdXlh7_h8!6+dxEgU7B!jQt2Ba$l`B?>l)Jx458JW4!QB1!_x
zXUdVxm5P#PWN>Fl;Yi_ZVMyUj6>VmYl5uB9;Y#6dVMyU_W{Q$cVG3r@<ar75mnQQq
z7N5-Gl3UD<d6l;~{0mAl^YapO{4^PF3Hs%yr1}@Rrlh9^m*f{Erl)E$-V*c8Pbtkw
zb<IsmO-aejOAjeZ%*;zI(qz0P?w49o5|WsXMT*NKwIsDD-zBjmQIqKwn`dr8eo@IS
z-hlk#61eJQkZX`JE0j~L!@$6h$`Hkv!Vtxj!kWsQ%96^O$|lK>!Ze4uogs}eg*k<#
zg*S@5oq>fRiX)golkJwEQ({SSMrsPeUO!FdTg=(1mA6>S5_3vZZ?P7a<QJtTgQ62^
z4+9$m0|SW94~kwX1_p*2#uUafh9aI6hG2#ghGvEZj5UlkObbDQ#{@Eur?51&7!+Rd
zskuqFxDxY<%TtTMY}TB_q|_WurdzDVr3I-)x47fulM_qgQ!;bo<2BiCu@tA~q!sZn
zFfbJHf(Ska28LTKiACwfx7f1FLG((-Tb%Ln$vKI|#qseg8GeQ8XXNLm>gQx68KxT<
z=;v0Z<R_QrrskCt>!%kb7G#(dmn0@<=a;1xrRC(8>!-q^r&zzZB(WqjIUcMkJ~=<H
zxTL5w8RTfag32Oc1_lOkkat0TXJFxAWMQn5#Rxk+WY;tLX>#7;h>y=p%uS7tzr__F
zp9@YA@$t8K;^PZT6LX+4%*7=|Ma&@YfdaJ%6rEuEZt+4rmYkZC6CYnB2~q+kKrSf)
z1$+@V0|SF7$Xy%^3=AC14FA~}+5Uq?F_H<$25^R)f}BhkQ<z#9qL@=zvRG5uve;9Y
z=P;$Pq_DQIL~+2y*yb>$u%~deutagf#W?3MfintA6jusQ3U3QT6n6?=3V#bj6i+H|
zDqkwEBtwe89Htb(6rmQDDE?HwQ~^nb6yZ6HDIzJNEi6%jsX{4Ya~M;^QzTkgqJ&c<
zgBdiXZixqjBf%XMQ_e_Hkjw~6d(fl@iUwzpzsndHz)7x_v4k;;DTOhG36#oHSbJG&
znM#<mm{Qn4PEX<NWv*q;WT;_?X91I}C2Y+g9(xK)4MRLfI*82)Cb_^QH<;uBle}P(
zFNLdyA)Y^lyM`fN0F=-*dHkx>_4M>}^HWlDiuFqp({)pGlXT7FO>;AIbWIIXbj^&-
zOpSGo^$JQ}8Zj_1WG?%`!0-|j4Vuiic;P7zl-O@^mE|WVCdC(LR;AwJOU}<LNv$Y}
z&q>WoFUhFVKv)7vL123fb2D=)Om&S7Qglr$Ee#Dpw#b2O`F@CjL6fOS8I)R77#J8r
zUV;j0O~E1%Rm2X;L0mcci7BAun3_`r3dvjC@!5I#<$3WXiRs0+_@Eq!WyRpsbc;1U
zC9x#&7H@o7W>IlTJXEO|NJo50emo+Br=+JA-x7g|!?HS9R1_wfoSzHIlJTivVM&ln
zN^?@Nm?MrPg<_Nhk_4)0objc_sqsmvMJ2Zc;|ubOOX8FBOY=(Nb24)?OEg)EL_h%`
z29`iXK~iaEPD*OgE#Z>HbWphjDY8L!q~_cbg9HS+I8uHDWkpa{&&<m#iH}zVrAts$
zFfg$(3NZ>WvN5tTaxk(nfoLen$i`TufFrT%#m7So(qy{D2@BZbTWm$CC8b4q$xKLD
z0aOBmA`q0>i?4yRLJeaxV=<&8W?aAoDw!o2Y8bMZYZz0QBpDVm$1vA2)iT$zq%haA
z)-a|sq_8Yvlw?@IvXH@rA$C#>OD$V1do4#TXAQ#w)*6-?R!N41jI~@f3=7z5I2SUe
zfXrfVWs+p5<*wm&VThHg<*8xF;#k00!;r;Q!z0O1!<NPbs+tQ6YPi7iyfqA2++ca0
z8eSy%2_?J>_-dFyHq<bKNRYl-z8cmPR&$10)*8Mr21bTLJ8)KG^Sj03l%G<0i!~&(
zBqy~9R2-~iF0uxtaT`#CvV<h27rz7vYx3V>&PmNHQUJ-Bf(T=f0<O%G)S{Bof*erl
zEGWuPPAx75#cyd&D!7C$zQtFNSX7(}PNF%P#U(|c+O$X)q?awRpdd9brAQNGngNK=
z0tvI0<QFAp++s`4&r41$(c}jwts+a1lt6iYQ3@!*#3zC4*Wz0O$@wX%C?dkpWCgAT
zV9L0Q^2;F!Hnr#$J2;hQrrcsl0{guL<o8=F;7rO6(OrCt1C(Sz4!*@)l9+yr4I+Mv
z1FSbSHwitN3c@@OPg&-m^a=_t21X7>7A6)(f&W}AEQ~Tte2gMY9E=iREXK&eSfz<8
z#bTJL$#jbyl3|OBtQZ&=aHrZE7^yZ3Bh{ub_af3Q3pCxb(kR^)5lFX_aHd-}q;%^5
zPPgoSh%~yA@fJ5Y1|ewzoQ8{RL8-}}fq_AjzX()0qb1uSN07V|hyc}gkTj_W;<|tc
zaDp@hu|Ns3$Q8tLgBS)X=88f=0`4Hf14MX&2rp3jWXvq`0dagmgdd3T2N3}vA`nCb
zfrwxb0d8=B32-UL4~uGO?1Houqb436?1@L6P~w4_sL6DT16rhGNjsnxAgB(*QQ{$|
z9Z*xCg`tF{gaK40HZ!^~G&82KAf+T$P)bT+OkqoBl4MxG3N8VgVwfQ%Ah^VXm4Bep
zkiCWjTK;i>%Req?I^nKi$l_eU1uErgxRJ`Yx*AS!`3EW=dBF0#DCOfkNZD7*Tf>^d
z4o*3|;1bUdoN_oIDF>zD11kAJ%?D7t6bmpgFn}5+u^hFGDGU>t3R!{~Rx<j5v%4nK
zEhasKTa3tUMsVzL*?>xl<ow(MyIckahR>j?q{@a+#K2-h&n727IWec$P7h&BQ5YzZ
zfC}CsI}i(;mie%y<y)*piFxU%Sd+9qNGl>qgIY^Ppd_uy2PsseK=N1;a!~|G0-Tgd
zO2q=OR1EW426`DN_aBsurBKR14p14WL2PJYL>;Jg4^FnAnxzC(k%FpfK2UVmFk~^h
zFvJSgGL|r9F)v`LVO+=vP7ER-*`jDro@RvvDmc%BGF*`vC?^Z1r<Q>0EO=u^lL-=L
z;Cu|uxL^Vt|2!~%gX)k5ke@;Q1_nkhMixdc#wry;egkQvk-tFE4JrLV^-+`v$X}4U
z3{-1E8e}l9flDW_ufPP@Yw&^zqN%tI!*>J=AgE4FrXo<OhFNccA`pZ@MR74GwSu}E
zEuh{na|#QnOjy8>!nzP#&eSs1Fu5?q%GEO0FoPnujIl_egbBoIW(1Kn%r*>#3N@e}
zGMgXP5Cvy)30Ni<sAa5SOkpTvEMlo)gk*M2Mop$$Yz37i8ToltTo78<Sg!(F0u;r7
zk_RJ%3l@X)7~>ci7$$&nO*{hwLzNZI7)NOgVa@rVtp4)<|NsB1nCxscnQyTc7vyA?
zKw3$lx(wRbh6E-kd){I!fj9&nYek@-><7gZsB^-=D8tCX$i*ncSfxgsyFi04nczW^
zcu@B}71XOJMrmt+;tPa978QeX3@pw-wE<czfr3}41Qb6E;8<b+$C3~@mYDIxk_;@C
zL_o2W#Q=(<EQS;YNocXG$q26OH5rRQ30IQ|95s-l3mg()0+fi0K#?;Y9yty;0|86>
z6<ZW(GJ#7`aKQnKjVzGypxXWxV=j7ZfEpUp5V64riwzx89HPkx$vL1<2h}ei42oUQ
zxCS^DrZIs=j23_kz=e#pj5SPfHd75VoXuRr2<g!GGBPsMFd?x)-Fn6pCXjdua}5Ke
zxNK&uWdZY8YnT=?HZ#_;f_ZE;%wQfHn8&`5p_#Fk9n9tcvpGsQYgn2Yi{_MYE#R(U
zS;&~jRl>D^2Q-MV5R_keQ&=S#;`vHAYgn5Ziw=QRv0_oh235t@%vkgStcndu6+6@{
z4yY>jX2uDOMFL=T>`3Z3q3XDx>NvpaEWqkGkkoNQ)$u^pF=X+BRpyj%EfA<-$l_nf
zxIl0r!vdj&40!@2Tv<FRyeZ5jTv@^?e7&IZh&P2_k|AEC1k`I21?yY})hS9)rvO~1
z7+B{6s7^70ItAf6#hV!?FcnEbb&3<zDFoLk0oEA;)hR(xr!ZKjWR}zd>4l&a$J@+U
z%UL41KxQFB4JU|Q!?{2fG_+j98O)$5;uoUHc#AzVuOKHe8Pd(n1PyoI;(+y*Zm~g{
zshJGS3=Fr}q3wqvMNqB+*SKmR7AT*;1U1(*8E>(slvEa^YBCprq;GMiq!uR^WhSK-
z7lB$}P~oB?kghyXH4jk-5y*!KfK0o^4JqeRQ{pr8U|K;EMa6Kfa<^DAQWH~(Oh6?+
zcTs9_YDs)%UP@}kEf!F-uP6khG7>~2f(TGsq$mZ%0+rU_-UDc~H!(T;7Hev5K}lti
zHb^oJL_qq%C8>GE`9(#UAoICEqg0vcrTL}Bw^(xW^U`l|frfJ+78Zf~DBxBdxW5eQ
zqugQ)xg}bZnwwvi8lPF5pOaXUni8LvpORX9OE?HD><JYG4W}31;t9zQ%}dEFN=+_F
zO}WL9n_7|x8iLZ~f;4YH{cK1F2h=!>;)6vH#K<TfNHjrtywG@paydbLX0X*!JdjEs
zCIK1!%`Zi-sBUq>M}+5s3P?~3m4T551eqWtA0r6zF!C^~F@c~4lLoUDBNQ?VF!C^R
zFjXnyDoEj``)P_4NrOThRH+onf>;)y;ASdFxy1>Z7jP>|%uR(v7F!4;K5)h>Xb`Fh
zlnlZ3Y!N5OG*B5|1ZwpaHG#O`x(7jkYavkUQXb?>A<&>Q6Ni|Hq=<xw1P2=jrwF?c
zJGhmmDR+ylq_QA0FCAh5s9#b9D)Mi!=j5lSXXd5fQUH&S#HXj`=cbkvWhU!^qE9aa
zJi`GVO#}~hM+u-!loUgTBn8lCNgzX)+)0VWnaR%id1;yHw>Z;NOMLRv(^HFXaTKSP
z_@tJl=G<cO^mFqEhv6;eoc#1#Yz2vVDT&2J%^>fBy2!UgusIA8Cg1@ha5#agtXop(
zDoctI^NQ2*i*k!^aUs}6wIH)WWAWgyL8?~4AqxuBTO2l!lx7F&D;0xcSA>y+i3gN=
Wm^c^(7`d20-3uNj0Y)BX0WJX4N80HC

literal 0
HcmV?d00001

diff --git a/embeddings/__pycache__/unixcoder.cpython-39.pyc b/embeddings/__pycache__/unixcoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..328777fd5508c3dbe7b68f74baa751b377c14a97
GIT binary patch
literal 8579
zcmYe~<>g{vU|?Y6_?LWMoq^#oh=Yt-7#J8F7#J9e{TLV+QW#Pga~Pr^G-EDP6cZze
z&6LBO%M!(s%NoU+%NE7P$l%V9!kog=!jQs}%9O?4%pAqx&XB^I!q&o&!j{U}%pAp)
z!Whh;$^H^#lb<H@ErFo?q|~C4#E|^#)V$29)S_EFP!Zq!l+>JCyilHVeqLH;dNN2G
zG6p#k#4iT9G?gKWF@+(DshuH>F@-UOsf8noIYlUiErq>>F^VOHBZaetA&NDHD}}p-
zA&M<UI7J{uu!S*-y`6!DA&MiIK~v-wXJ}q#L~?#gYEd#1$XqA}u{jwS7??rfv5kR&
zp@gA?v4)|UX#vwh21bSw<^?P@j0+iS7~)w=*x+pT6vi5cEcSSg63#4^8ZbMBsh6pS
zA)c#-VF7mva|zD^-V&ZHz7&=co-F<p)?TKCj0*%n4o+cT$h?q=k)cE|O9*5}mN1wU
z0ePaAxrQNLw1y#zD@!a(JWHa6Axkt~qC_xD66CdBkSnBW7_wNiq*Hj(K>@+r3yLki
zU<OV8s_9$`NWiHyGbcp>?q@wN6xo2pqQu<PlGLJND^v+Bg@jyC?8N6K=B6g-DrA(V
zr)TD+rzIw*Du5*vk`s$k6hN|i3aNVOdJ4Ij$wm3a`DrEkrFoeZ-~iT5N-R!=+ZUqA
zc#AC=9Lcv>k}?yEHJM&AurM$baWF72XfobnO)04?NPWr302Ol6WGrH5U|@Lp|NsC0
znvA!YQ!;a_*lltXi?a>vG#PL46&Iu?XC~&vmw*yhv7e^QE!N`Fg4CiSRt5%!TO9H6
znR%Hd@$p5VsJ+FXR+OI`Ur>}<Qk0mPmzr{m6RZuC9*aOptcVSyi7T@>J|z|G_##kB
zxWx+g>@EJH)bz~alGLL3q|&ss)S_FgCHY0k8MjzUiZXL<vE=8a7T;nm&CDyYxW!VI
znOc5}DKGC9n@?t5YGTnX?u^Wol+?WV;>@blTU=%N$%#o2HhWHPd`4<w$}O(c+@#c$
zl+3*J;#>UX`9&%5NCIr-shQ~+CAYYI^3#L!(@Js^D{e986zAUJN=!+CdR~+17E5tz
zPTDOluzSI=wvzD{XMB8ePGWI!eEdp=U#|KY`MIh3IT=ZY=|%?nxs@sT$)&lec_qdA
z=|zbJ879RgiOJddWvNAJIr-)K2p1qCRIi}2h@XLh0aUCMi-2+t2OlF3BM&17qY#q@
zqZFe6BOg<h5F)j~;@}sbO==!ESatLBaw_d!GB7bPxFv%k3S<fhgUkfQfioxr{$OBW
zs9{QBEMq9<@nUFVtYJ=J%4R5*0J9mg7;0J58EV;T7+n}*&(yNlur6S%VPD8t%TdFW
z!ko=ilu^Tx!cxPW!YaYg$|T89%TdFc#Z<#o!<xmM&RoMG&H!dHiZj%5<VDmlrLbi)
z6-}$*s9{cFmtf##NMXUQtCk~A3u-b4VUufE7qHZDEM%<ZtYIm7QNsz&ux#m!HOwXK
zHJr_i5)2{?;taK%HEatwYB(1%GBOmZ)vyIKXmTc2FfuT}^Nn+UURi2Ui9&HnQD$Dc
zLP@>?IOQs2rW7OcEY^&yrI3)BS5R6~oS>_alUZD%ke{Xi7J!;ntOvIQl;0C8;&W2-
z(n~TDbQMA}QWbI&D>8FSa}`SROA>Pwit|g0l2a9mQwvK|^O92)AW9Vy(@Ii{ATG+R
zN-W9D&qFb^AQ7Cb6Lb~IGg3=3z)BSg5>s#)SCS917U~WZlRya@Y)ocy2Dk`PD9J!_
zPI+cdjzVd1s-8loUaFn~YI4(2uz}`kNb=J`$V0{L;31*OT%^sw!0?L?RZEp1+(?)f
zaKQyu?DrB}gy@1wD^RiV5|pqtIg12AWeNv4x<GlKGdZUil3{Og7N-_KnWD*%2!?P$
z3gR<UiXr6@PeCFyi^pfC++xm2&AY{xSWu9fm!ip4qzW>f4IH7xw^%^obBik%9*MWu
zp<xek1sB*A#UP_^F;?7SgK(oDr65SSNE1}>f(b1U%aDPAK^~MB)j+wEn}?ByiH(Vc
zk%N(ok%LiyQ3jOf8AX@`7^RrF7+Dyl7<m}0ka9lQ{LHFUlu8y<(u0b0a2_vWU|>jR
zsAU4@;*?tEbcR}%8fF)USf5(f8rB623mFzLE@Y@>D`Bc(Yi6uvOJS^GD`O~ft6@uF
zs$pBeoWi(}5u8uiY8bLuvRG@_Kn;^#rdswImIZ8}@_><{P^5-w0b30VICpW>Fa<Mc
zvR26|z_JwBk)VtO%2m+(RRpSl{M<AdZ^<O*q$cLYmqH^7n!MwS3lfu4i#6GAF(((4
z-eS!w%`K?B#RAG-w^%dtQc^1*0nP?ZCgz~pjRO){T#!rzjs)(4qSTbkWKfPNF0uth
zB^N^E7Gvct#*$ksC8-r9=n><>z`)=HiWmt{#DJp&6fK}an1fM>Q2<;ZR|y~rVTeOe
zA_9~?;e~AssIX;Rz)-`G#hAskkg=Ajh9Qf&hOvfe0ShRQ!8FrCMiGV-=3b^07D<L$
z<{D<OC@V;G0V`A$YYOW^#sw@OQW_kYH7seMAZM#GRPawKNzBZH6ch@Gx<w%|FGZm^
zHLnC*rXYzIf$D>opsIT%<1N0#l9JRsP*BB#s()}46>)<CmhBc(UMe`R6q$ljtuQDc
zIZN}33rkZ|t5R<<7nkO0vVudu$Q2~V0x}><037p3wQQ6y%n?X}Xd!=#JuSbeJh3Px
z5ES~Lw9LT7$Ed`p!dNAU2yLjE5WgY=P;mfiwI+jt8WdR|45~IcVGYd-OrUzMNUw$=
ziy?(Eg|U?>g|UWl0b>diC~YIN)0ib0YT5FXY8V%QqbG}HA!99D4ch|N6y}ADV44LK
z8MSOR%qfhe3`Gf`NX%lZWv*e&Vy|USXQ<^U;aI>4QoVo+6h#ZTL3yN;rGy7$CRm&o
zWbXpLg$xV$7cw+6b~3}<QVXiR1Zo&-*izVfnQFO81Zx<xgi=@~8JZbg7$z{r{_13G
zXJ}_^XHH{IVNc;`;i%!t5_Vx|W~}8d5$Ry4VXk3sX0GL_;aMQMkYRz?LWWx28paOB
z4#pH_aR!jDJD58dB^g?o(wHO}7KnE+q%bdJs^zQU%96mZDusC=b1i>~U<VV_B}%md
zH9QL>7cxv>EJ|Wzn7~w6R3nhWn$0wUvFKP0KQxq?AT&5s1#1`<NYyYbkggG2$heR(
zOQx2;hG&6n4L_*VsNv0$Ss(`{<x^N|corxuWLU@qmO=7M-ijJNNUYWHLTGT!&X6Tg
z!;6SpxEfB7k83zvnLzf}aHnwgGG|HD@>A1wB^(PBYj_dioyL^Hw2-NmcLHOfPYurk
zr5c_E$_p75GS&*!@GMZN5n9NY!q@?6EeW?XwX>u#rEr1c5fY2SH5?06!Db4lur6d;
z0M6>5P$)bE7Uckm)^gN<`dHjmdJ68Td8tK-B}f%?dS+Q_o<eed9;i%H&?p8sCqN~k
zrh*?hFN1S@kvAw8gNooH9|i`7WDpCKYd{!O7=Q{cP#!N<0@ofj46#zRkd()i#tiDh
zws3&zlV-+=Ooc4L44O<BCV=uS*aRz>2}ZSy(6q<^NsEwt*uex%k<5||3m7^;oneL)
zhJ`FJlfmxw19fVOQVUWOOI9-7V$w4J7v(FNZ*eB2Cgwt#`HY#jU<H3pY96Fe0adZL
zxNJbRMsj{`fn6R01H)&KCx00tijs6_AnMuV<R>TQ6x-<`G$1j-CTcPjg@J<314M*_
z2t-vL&A`9_vY;5$N?~JU`d_6+oF$q}MZTaylfNLbxFkM1wKBdeF{d=O7$Z<Y1sX=6
zmM|<}1V;%2BuY9M(-}J$N|<UGVU6E~j0>1km_S(_8mKH7CW9h87*yVh!J>-~6kQ8I
zQ2~o&MsWJkWGs>d`8yOufISZ;Kp|ZO3V{xgIH;~GiUqMi^$SEFOOYtZlg!D9B}HZ+
zDUdPH@Vv!XpvhDON-DQls#1&c!C_s5X{M$cxUYJP1=Pta0*$92RZ9LK<9Sl^L2bM2
z)I3mAjx8m%EHgP3Vl$UdeqMS=YF=@E(Jj`r%$%I~TTCVSMW9N)C;;RjP$g322x5T(
z`4($YVqSXcE!N_U#DdhKIFRWqPN|8xxA@ajOPouKic<4Rf=d!hQg5*)=j7+57KMP6
z2<Cty5!M$4OS7baWR*b%%Yl@zq$HLk-eONoDND>t2I<Z(sEohG18&vC7pLZ=CYRh2
zK-l45l$oBHR}>G@%9fs5;!#;}iz7X?#4R%~G3S<G5@_flB(W$xwFEr&Q(OdY3yC0V
z<0u|T+oUKS)GLf)0|ilLN>K#JL7*@%0@budDIoJ%L5-v0TWpYSM-(5lI~ku}T2fG2
za*G{Wu-sw+<-1#);8roHbcy1I%78exnDa|ZZZXE+;zbUI;#<ral?At0KsJEmSCbXu
zgh;qMqCl)f5CMvMkS8I<8_1=_AcBE~g;9%%k5Pn)gOP)YjhTf}j8TS>i;<5}gprSt
z<v$w}sGF?9sKLm`$i>LU$i*neD8$IcD8<OfB*Li1$n>9s2__=L$j2zeD8eYk#DmhA
zh7||U9!il80|SGbCJVHsQ3T2*nw+3<574MgeEco0`1suXl+v8k`1o5q@$rSFi8&A%
za38)1)W9zSH71LKKptxZg)%R6z#us_Cnr8092gKn9i#;0ckti}D7zPfG7twNk0_Ti
z2R{cpSR7;U4pd}=5^pi6FozA^fd}pwqL@L=^%jOG7Rcyf6f3yl02;RAOyO!_h+<FS
zPT^@`h~h}$P2p={h~iA)PZ4Ngh~i2S3}(<2D#`-Y51<kaRBVDUhz;s}fqNIANC6GV
z#e>=&4DpO5ObeJn19%Hq7BVbgtpT-}7~<K|8EP2f*-MxfaDe4FK@C}mY&vL=jjIMU
zl*gR{B6+}Sc~cmI88n&wz)es$O}?TeP{4pvC?qR$yXE93mca5dsKkgbN(Y53TXAxJ
zQEKrmwt}M6GVkJ|RFJ+juuNWRMM-4wEtdGy{Nh`nJP?v!aEl`?GcU6^BQ-^n1(IJt
z7Tsb_%`XP$2S}_S3OLZH%tTNzDF6xyMm9z^CNU;9CMhNmtBMy?I_SX$V=&5CP{4r$
z4OGTLnk$l^vK!P4WdfD9ATL*Gx~G<a``nN$SFDhhUj*hSLrNTllFZ!H;*!(?J#euK
zt{yd+in>4^g_fwQApf!^C6**-pnD*jfq`KG$O9n9Ffa)*R*8TejVNuvB^rv`L5%}Y
z3In?xlpJAh2UYx4dNB7UB_?MV<Y(rAhB*m1T9dJ;lYxO@B_kxv!Ql%gz)8QDfq`Ku
z$c-RN8JMa>(Od{A4^cW=AXkA}|6nItFfuUIFfL#KrLtNk(C~5=Q!`@<Q#MOc8Yr!1
zF+)0B3s^y^QJkTcxdha=WlsUOKtOGCP#ZT^p_T<SR16ZWVUYxN+c>hBikv_rUNtN>
z424d495sw7pm8te8fHj~AKatn0tXLM3aIVI62nx>R?A+)R>Kg_T*95jQ^O7#2TNh^
zWv=C@VaQ@yzzZ6?Vg#AazJRxey@sQPNrXXy0o3Nh>MA>|uDSv0g@9ZM?m_S^;H}{V
z`<uO)F@;frp@t=m8DwH1NFCUpAa~a=WO0Do>n16zH4O3mDWFy*XQBu4sF*vrIayGY
zpQHd9vQz+#nnIFfYFTPgr2?q3QvlTk3JD-7pX>y1R}<cmgpc<mB$kxq`InX?SSdK?
z=N6Qfq=FTe7Nmfhub?p|uqnl<iABj7NW<s2M%<vb=_-KSR;-YprU0o<L1TBIF=%j>
z0lN;AR}?h76)GT(*MvJsO9AdET?M#{6%rxY53Uz6`WKX1Qd*Q(Y^4xVRGO-gnFjGV
z#K{Vo#R{M%W=?8JDq=tloIAmV7&x=0JM?U=kX<i&B;*!rT26j`ktX9U-lF`{JkXdl
zs1`}}(-eS|b>J2fD3cU{8&u%74QFaaL1JD?d}47?1}M9O^L#gm1<LrhSW5B>vLWRo
zb4q3zxEWB?3zF^u_2D5SSy60|kc;BTE6okhFG?wnVgrS|Pj(Tg9**Kl%SkK&WjLSg
zD6XW`;*wx+ff~hwV0)%Sftm?r-nSTw(aV<tP__o;UGQ)g4;M2ZqZqRUqX;7p6KF_F
zjERi}G;$)s1QO$7lw%fStl|M#rw56?RNN&GtkldG0Vi#yX3#tkb2e)*!%9X!O_m~%
zCm{_>u%AH%dC>%r@4+Ptm;e_qpeFiNkQ2eh3u6^0*m_VC4aIIy0SO;Odj+1U0*%~&
zGE_EGkpYMe86rqwlw>FYWndOiAtS+%!U)N1AQf4xklHtewU?0*qKXmhZnkXZqFd<<
zwTva~pkcKXHc5sBY@h-QT=dw0$A(xyqBTs=qKCOC3!KRz1rBI<f*IM?8YYOXg;_O>
zS)e)|G<OKLwT6L(A&aqwF$LsCHn1LW&j2*2XU@RH5DyARs4Pe<m_d^RQt^UQFG`@m
zssd0W2;4HMVE}obn<1ECB@<K=D5O_1-eOEdOB=U1Y;qEFlTs4xZi8Y4)J3S02ZueZ
zEsNX%gQx_}M`fp0f-?`X4k?-eG9KiLqM2|Om<x{n*$fN}k3ixe2WWC4r5I4N8Is#T
z8S55Haeh$=Br}3@8z`eeTM4(gO7mb12XMA2ng%ivYzdeE+f>THz;FYU(7+>zjG!?r
z(6}KBBP8w1FbXg#F+v*+upmIG)j)k{coM7vcg&#UxH&P*wam3Fkog(N6lxY|Oba}t
z0!@a@Da_5Fab7J@{j-1>l+0?_Q<%gVBpGU0CBZces0ONGhwy7z3r%W3tv5E%m`w_E
zD>Fn*A!Iz4rOFmwaE2%5WGjHGd(e1pa(-TMNl|Gsq|{6+&B;;7s4U1Y$w)2EM4I{V
z11BbML~61^iZ4VF2M&2qpld=#aHBXtMN)iGYH<`xMrA>9Q7x!N!vSmg++xfwst1WM
zW}&6Aq7nuMhG(EK2RG1ISU|l3MkPiGMgc}K#wu=Da`UJxKnZD3+X2)C2DcwT9YXks
zxe3yUxeTa&2i5M`Ohsyp424EDEDIQGSiv<bXat-Qc?7&@I;hZPDgq@Na8azu25Ie-
zgIvx58wV|F0Hqu|kQj4Gem1&)%Af-l#X{f)Cujs3gh5jQpq@^Z7}!S`T|hrg?xH4;
z!Ob9|1w??Fu0^2cUlFK%RRk)ii$FzT5h$60!>tHp8Il7KDH+rU02QUhpit#t<lyIE
z<zV5^;b7+A<KPAN5Hz_V6%Pk^!d5RYuSg2yC7zO^#Ju9P{G!~{qGIq2F?hZgJQaM4
u%@^El0(WU{vE?Uar6!kv4a4S!TO2l!UX~rGU<Q{x9E?1SBFsYd3>*M{oGTyz

literal 0
HcmV?d00001

diff --git a/embeddings/custom_logger.py b/embeddings/custom_logger.py
new file mode 100644
index 0000000..6fb2391
--- /dev/null
+++ b/embeddings/custom_logger.py
@@ -0,0 +1,16 @@
+import logging
+import sys
+
+
+def setup_custom_logger(name, level):
+    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+    screen_handler = logging.StreamHandler(stream=sys.stdout)
+    screen_handler.setFormatter(formatter)
+    logger = logging.getLogger(name)
+    logger.propagate = False
+    logger.setLevel(level)
+
+    logger.handlers.clear()
+    logger.addHandler(screen_handler)
+    return logger
diff --git a/embeddings/dataset_in_memory.py b/embeddings/dataset_in_memory.py
index 153055e..1ecbf6a 100644
--- a/embeddings/dataset_in_memory.py
+++ b/embeddings/dataset_in_memory.py
@@ -1,5 +1,7 @@
 import logging
 import os
+import re
+from typing import List
 
 import torch
 from torch_geometric.data import InMemoryDataset
@@ -9,9 +11,12 @@ from custom_logger import setup_custom_logger
 log = setup_custom_logger('in-memory-dataset', logging.INFO)
 
 class UserGraphDatasetInMemory(InMemoryDataset):
-    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
+    def __init__(self, root, file_name_out: str, question_ids:List[int]=None, transform=None, pre_transform=None, pre_filter=None):
+        self._file_name_out = file_name_out
+        self._question_ids = question_ids
         super().__init__(root, transform, pre_transform, pre_filter)
         self.data, self.slices = torch.load(self.processed_paths[0])
+        self.data = self.data.apply(lambda x: x.detach())
 
     @property
     def processed_dir(self):
@@ -27,7 +32,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
 
     @property
     def processed_file_names(self):
-        return ['in-memory-dataset.pt']
+        return [self._file_name_out]
 
     def download(self):
         pass
@@ -37,15 +42,34 @@ class UserGraphDatasetInMemory(InMemoryDataset):
         data_list = []
 
         for f in self.raw_file_names:
+            question_id_search = re.search(r"id_(\d+)", f)
+            if question_id_search:
+                if int(question_id_search.group(1)) not in self._question_ids:
+                    continue
+
             data = torch.load(os.path.join(self.raw_dir, f))
             data_list.append(data)
 
         data, slices = self.collate(data_list)
-        torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0]))
+        self.processed_paths[0] = f"{len(data_list)}-{self.processed_file_names[0]}"
+        torch.save((data, slices), os.path.join(self.processed_paths[0]))
 
 
 
 if __name__ == '__main__':
-    dataset = UserGraphDatasetInMemory('../data/')
-
-    print(dataset.get(3))
\ No newline at end of file
+    question_ids = set()
+    # Split by question ids
+    for f in os.listdir("../data/processed"):
+        question_id_search = re.search(r"id_(\d+)", f)
+        if question_id_search:
+            question_ids.add(int(question_id_search.group(1)))
+
+    #question_ids = list(question_ids)[:len(question_ids)* 0.6]
+    train_ids = list(question_ids)[:int(len(question_ids) * 0.7)]
+    test_ids = [x for x in question_ids if x not in train_ids]
+
+    log.info(f"Training question count {len(train_ids)}")
+    log.info(f"Testing question count {len(test_ids)}")
+
+    train_dataset = UserGraphDatasetInMemory('../data/', train_ids, f'train-{len(train_ids)}-qs.pt')
+    test_dataset = UserGraphDatasetInMemory('../data/', test_ids, f'test-{len(test_ids)}-qs.pt')
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index 95a6b99..0943795 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -19,6 +19,10 @@ from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
 
 
 class HeteroGNN(torch.nn.Module):
@@ -28,25 +32,24 @@ class HeteroGNN(torch.nn.Module):
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             conv = HeteroConv({
-                ('tag', 'describes', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('tag', 'describes', 'comment') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'question') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('module', 'imported_in', 'answer') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('comment', 'rev_describes', 'tag') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('question', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
-                ('answer', 'rev_imported_in', 'module') : GATConv((-1,-1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
             }, aggr='sum')
             self.convs.append(conv)
 
         self.lin = Linear(-1, out_channels)
         self.softmax = torch.nn.Softmax(dim=-1)
 
-
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
-        #print("IN", post_emb.shape)
+        # print("IN", post_emb.shape)
         for conv in self.convs:
             x_dict = conv(x_dict, edge_index_dict)
             x_dict = {key: x.relu() for key, x in x_dict.items()}
@@ -58,29 +61,37 @@ class HeteroGNN(torch.nn.Module):
             else:
                 outs.append(torch.zeros(1, x.size(-1)))
 
-        #print([x.shape for x in outs])
+        # print([x.shape for x in outs])
         out = torch.cat(outs, dim=1)
 
         out = torch.cat([out, post_emb], dim=1)
 
-        #print("B4 LINEAR", out.shape)
+        # print("B4 LINEAR", out.shape)
         out = self.lin(out)
         out = out.relu()
         out = self.softmax(out)
         return out
 
 
+'''
+
+'''
+
 
 def train(model, train_loader):
     running_loss = 0.0
 
     model.train()
     for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
-        data = data.to(device)
+        data.to(device)
 
         optimizer.zero_grad()  # Clear gradients.
-        #print("DATA IN", data.question_emb.shape, data.answer_emb.shape)
-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
 
         loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
@@ -89,13 +100,12 @@ def train(model, train_loader):
 
         running_loss += loss.item()
         if i % 5 == 0:
-            log.info(f"[{i+1}] Loss: {running_loss / 2000}")
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
             running_loss = 0.0
 
 
-
 def test(loader):
-    table = wandb.Table(columns=["graph", "ground_truth", "prediction"]) if use_wandb else None
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
     model.eval()
 
     predictions = []
@@ -104,9 +114,12 @@ def test(loader):
     loss_ = 0
 
     for data in loader:  # Iterate in batches over the training/test dataset.
-        data = data.to(device)
-
-        post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
 
         out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
 
@@ -115,13 +128,26 @@ def test(loader):
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         predictions += list([x.item() for x in pred])
         true_labels += list([x.item() for x in data.label])
-
+        # log.info([(x, y) for x,y in zip([x.item() for x in pred], [x.item() for x in data.label])])
         if use_wandb:
-            graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            
             for pred, label in zip(pred, torch.squeeze(data.label, -1)):
-                table.add_data(graph_html, label, pred)
+                table.add_data(label, pred)
+
+            
+
+    #print([(x, y) for x, y in zip(predictions, true_labels)])
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions, 
+        "trues": true_labels 
+    }
+    return test_results
 
-    return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table
 
 def create_graph_vis(graph):
     g = to_networkx(graph.to_homogeneous())
@@ -132,6 +158,7 @@ def create_graph_vis(graph):
     fig = vis.create_figure()
     return fig
 
+
 def init_wandb(project_name: str, dataset):
     wandb.init(project=project_name, name="setup")
     # Log all the details about the data to W&B.
@@ -145,7 +172,7 @@ def init_wandb(project_name: str, dataset):
         n_edges = graph.num_edges
         label = graph.label.item()
 
-        #graph_vis = plotly.io.to_html(fig, full_html=False)
+        # graph_vis = plotly.io.to_html(fig, full_html=False)
 
         table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
     wandb.log({"data": table})
@@ -158,106 +185,116 @@ def init_wandb(project_name: str, dataset):
     # End the W&B run
     wandb.finish()
 
+
 def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
     wandb.init(project=wandb_project_name, name=wandb_run_name)
-    #wandb.use_artifact("static-graphs:latest")
+    # wandb.use_artifact("static-graphs:latest")
+
 
 def save_model(model, model_name: str):
     torch.save(model.state_dict(), os.path.join("..", "models", model_name))
 
+
 if __name__ == '__main__':
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     log.info(f"Proceeding with {device} . .")
 
-    in_memory_dataset = False
+    in_memory_dataset = True
     # Datasets
     if in_memory_dataset:
-        dataset = UserGraphDatasetInMemory(root="../data")
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
     else:
         dataset = UserGraphDataset(root="../data", skip_processing=True)
+        train_size = int(0.7 * len(dataset))
+        val_size = int(0.1 * len(dataset))
+        test_size = len(dataset) - (train_size + val_size)
 
 
-    train_size = int(0.7 * len(dataset))
-    val_size = int(0.1 * len(dataset))
-    test_size = len(dataset) - (train_size + val_size)
-
-    log.info(f"Train Dataset Size: {train_size}")
-    log.info(f"Validation Dataset Size: {val_size}")
-    log.info(f"Test Dataset Size: {test_size}")
-    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+        train_dataset, _, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])  # lengths must sum to len(dataset)
 
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")
+    
     # Weights&Biases dashboard
     data_details = {
-        "num_node_features": dataset.num_node_features,
+        "num_node_features": train_dataset.num_node_features,
         "num_classes": 2
     }
     log.info(f"Data Details:\n{data_details}")
-
+    
+    log.info(train_dataset[0])
+    
     setup_wandb = False
     wandb_project_name = "heterogeneous-GAT-model"
     if setup_wandb:
         init_wandb(wandb_project_name, dataset)
-    use_wandb = False
+    use_wandb = True
     if use_wandb:
         wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
         start_wandb_for_training(wandb_project_name, wandb_run_name)
 
-
-    calculate_class_weights = False
-    #Class weights
+    calculate_class_weights = True
+    # Class weights
     sampler = None
     if calculate_class_weights:
         log.info(f"Calculating class weights")
         train_labels = [x.label for x in train_dataset]
-        counts = [train_labels.count(x) for x in [0,1]]
+        counts = [train_labels.count(x) for x in [0, 1]]
+        print(counts)
         class_weights = [1 - (x / sum(counts)) for x in counts]
+        print(class_weights)
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
 
+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+
     # Dataloaders
-    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=64)
-    val_loader = DataLoader(val_dataset, batch_size=16)
-    test_loader = DataLoader(test_dataset, batch_size=16)
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+    
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)
 
     # Model
-    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+    
+    # Experiment config
+    INCLUDE_ANSWER = False
+    
+    optimizer = torch.optim.Adam(model.parameters())
     criterion = torch.nn.CrossEntropyLoss()
 
-    for epoch in range(1, 5):
+    for epoch in range(1, 40):
         log.info(f"Epoch: {epoch:03d} > > >")
         train(model, train_loader)
-        train_acc, train_f1, train_loss, train_table = test(train_loader)
-        val_acc, val_f1, val_loss, val_table = test(val_loader)
-        test_acc, test_f1, test_loss, test_table = test(test_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)
 
-        print(f'Epoch: {epoch:03d}, Train F1: {train_f1:.4f}, Validation F1: {val_f1:.4f} Test F1: {test_f1:.4f}')
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
         checkpoint_file_name = f"../models/model-{epoch}.pt"
         torch.save(model.state_dict(), checkpoint_file_name)
         if use_wandb:
             wandb.log({
-                "train/loss": train_loss,
-                "train/accuracy": train_acc,
-                "train/f1": train_f1,
-                "train/table": train_table,
-                "val/loss": val_loss,
-                "val/accuracy": val_acc,
-                "val/f1": val_f1,
-                "val/table": val_table,
-                "test/loss": test_loss,
-                "test/accuracy": test_acc,
-                "test/f1": test_f1,
-                "test/table": test_table,
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
             })
             # Log model checkpoint as an artifact to W&B
             # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
-            # checkpoint_file_name = f"../models/model-{epoch}.pt"
+            # checkpoint_file_name = f"../models/model-{epoch}.pt"
             # torch.save(model.state_dict(), checkpoint_file_name)
             # artifact.add_file(checkpoint_file_name)
             # wandb.log_artifact(artifact)
 
-    print(f'Test F1: {test_f1:.4f}')
+    print(f'Test F1: {test_info["f1-score"]:.4f}')  # report the test split, matching the label
 
     save_model(model, "model.pt")
     if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
         wandb.finish()
+
diff --git a/embeddings/hetero_GAT.py.save b/embeddings/hetero_GAT.py.save
new file mode 100644
index 0000000..865939d
--- /dev/null
+++ b/embeddings/hetero_GAT.py.save
@@ -0,0 +1,297 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
+import wandb
+from torch_geometric.utils import to_networkx
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+torch.multiprocessing.set_sharing_strategy('file_system')
+import resource  # stdlib; only needed for the file-descriptor limit query below
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+# resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+class HeteroGNN(torch.nn.Module):
+    """Heterogeneous GAT stack whose mean-pooled node features and a post embedding feed a linear head."""
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin = Linear(-1, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        """Return softmax scores with one row per row of ``post_emb``."""
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: x.relu() for key, x in x_dict.items()}
+
+        outs = []
+        for node_type, x in x_dict.items():
+            if len(x):
+                # Look the batch vector up by node type rather than relying on matching dict order.
+                outs.append(global_mean_pool(x, batch=batch_dict[node_type], size=len(post_emb)))
+            else:
+                # The placeholder needs the pooled shape and device, otherwise torch.cat below fails.
+                outs.append(torch.zeros(len(post_emb), x.size(-1), device=x.device))
+
+        # Per-type graph features first, then the post embedding, all along the feature axis.
+        out = torch.cat(outs + [post_emb], dim=1)
+        out = self.lin(out)
+        out = out.relu()
+        # NOTE(review): the script trains with CrossEntropyLoss, which expects raw logits — confirm relu + softmax.
+        out = self.softmax(out)
+        return out
+
+
+'''
+
+'''
+
+
+def train(model, train_loader):
+    """Run one optimisation pass over ``train_loader`` and log the mean loss of every five batches.
+
+    Relies on the module-level ``device``, ``optimizer``, ``criterion`` and ``INCLUDE_ANSWER``.
+    """
+    running_loss = 0.0
+    model.train()
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+        optimizer.zero_grad()  # Clear gradients.
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if (i + 1) % 5 == 0:  # a full window of five batches has been accumulated
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+
+def test(loader):
+    """Evaluate the global ``model`` on ``loader`` without tracking gradients.
+
+    Returns a dict with "accuracy", "f1-score", the mean per-batch "loss", the W&B "table"
+    (``None`` when ``use_wandb`` is off) and the flat "preds" / "trues" label lists.
+    """
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if use_wandb else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+    loss_ = 0
+
+    with torch.no_grad():  # evaluation only: no autograd graph is needed
+        for data in loader:  # Iterate in batches over the training/test dataset.
+            data.to(device)
+
+            if INCLUDE_ANSWER:
+                post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+            else:
+                post_emb = data.question_emb.to(device)
+
+            out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+            loss_ += criterion(out, torch.squeeze(data.label, -1)).item()  # Accumulate the batch loss.
+
+            # Plain Python scalars, so the metrics and the table never hold tensors.
+            batch_preds = out.argmax(dim=1).tolist()  # Use the class with highest probability.
+            batch_labels = data.label.reshape(-1).tolist()
+            predictions += batch_preds
+            true_labels += batch_labels
+
+            if use_wandb:
+                for batch_pred, label in zip(batch_preds, batch_labels):
+                    table.add_data(label, batch_pred)
+
+    return {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score": f1_score(true_labels, predictions),
+        "loss": loss_ / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels,
+    }
+
+
+def create_graph_vis(graph):
+    """Build a figure of ``graph`` (made homogeneous) laid out with networkx's spring layout."""
+    nx_graph = to_networkx(graph.to_homogeneous())
+    layout = nx.spring_layout(nx_graph)
+    visualization = GraphVisualization(
+        nx_graph, layout, node_text_position='top left', node_size=20,
+    )
+    return visualization.create_figure()
+
+
+def init_wandb(project_name: str, dataset):
+    """Publish a one-off "setup" run describing ``dataset`` to W&B.
+
+    Logs the module-level ``data_details``, a table with one plotted row per
+    graph, and the ``../data/`` directory as a dataset artifact.
+    """
+    wandb.init(project=project_name, name="setup")
+    # Dataset-level details go first.
+    wandb.log(data_details)
+
+    # One table row per graph: figure, node count, edge count, label.
+    overview = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
+    for sample in dataset:
+        figure = create_graph_vis(sample)
+        stats = (sample.num_nodes, sample.num_edges, sample.label.item())
+        overview.add_data(wandb.Plotly(figure), *stats)
+    wandb.log({"data": overview})
+
+    # Upload the raw data directory as the "static-graphs" artifact.
+    artifact = wandb.Artifact(name="static-graphs", type="dataset", metadata=data_details)
+    artifact.add_dir("../data/")
+    wandb.log_artifact(artifact)
+
+    # Close the setup run.
+    wandb.finish()
+
+
+def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
+    wandb.init(project=wandb_project_name, name=wandb_run_name)  # start the run under the given project and name
+    # wandb.use_artifact("static-graphs:latest")
+
+
+def save_model(model, model_name: str):  # saves only the state dict, to ../models/<model_name>
+    torch.save(model.state_dict(), os.path.join("..", "models", model_name))
+
+
+if __name__ == '__main__':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log.info(f"Proceeding with {device} . .")
+
+    in_memory_dataset = True
+    # Datasets: either the pre-split in-memory files or a random split of the on-disk dataset.
+    if in_memory_dataset:
+        train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
+        test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
+    else:
+        dataset = UserGraphDataset(root="../data", skip_processing=True)
+        train_size = int(0.7 * len(dataset))
+        val_size = int(0.1 * len(dataset))
+        test_size = len(dataset) - (train_size + val_size)
+        # random_split requires the lengths to add up to len(dataset), so the validation
+        # share is split off as well even though this script does not evaluate on it.
+        train_dataset, _, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
+
+    log.info(f"Train Dataset Size: {len(train_dataset)}")
+    log.info(f"Test Dataset Size: {len(test_dataset)}")
+
+    # Weights&Biases dashboard
+    data_details = {
+        "num_node_features": train_dataset.num_node_features,
+        "num_classes": 2
+    }
+    log.info(f"Data Details:\n{data_details}")
+
+    setup_wandb = False
+    wandb_project_name = "heterogeneous-GAT-model"
+    if setup_wandb:
+        init_wandb(wandb_project_name, dataset)  # NOTE(review): `dataset` is only bound in the on-disk branch above
+    use_wandb = True
+    if use_wandb:
+        wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
+        start_wandb_for_training(wandb_project_name, wandb_run_name)
+
+    calculate_class_weights = True
+    # Class weights: each sample is drawn with weight 1 - (share of its class in the training set).
+    sampler = None
+    if calculate_class_weights:
+        log.info(f"Calculating class weights")
+        train_labels = [x.label for x in train_dataset]
+        counts = [train_labels.count(x) for x in [0, 1]]
+        class_weights = [1 - (x / sum(counts)) for x in counts]
+        print(class_weights)
+        sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+
+    TRAIN_BATCH_SIZE = 512
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+
+    # Dataloaders
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=14)
+
+    test_loader = DataLoader(test_dataset, batch_size=512, num_workers=14)
+
+    # Model
+    model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3)
+    model.to(device)
+
+    # Experiment config: read as a global by train() and test().
+    INCLUDE_ANSWER = True
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    for epoch in range(1, 40):
+        log.info(f"Epoch: {epoch:03d} > > >")
+        train(model, train_loader)
+        train_info = test(train_loader)
+        test_info = test(test_loader)
+
+        print(f'Epoch: {epoch:03d}, Train F1: {train_info["f1-score"]:.4f}, Test F1: {test_info["f1-score"]:.4f}')
+        checkpoint_file_name = f"../models/model-{epoch}.pt"
+        torch.save(model.state_dict(), checkpoint_file_name)
+        if use_wandb:
+            wandb.log({
+                "train/loss": train_info["loss"],
+                "train/accuracy": train_info["accuracy"],
+                "train/f1": train_info["f1-score"],
+                "train/table": train_info["table"],
+                "test/loss": test_info["loss"],
+                "test/accuracy": test_info["accuracy"],
+                "test/f1": test_info["f1-score"],
+                "test/table": test_info["table"]
+            })
+            # Log model checkpoint as an artifact to W&B
+            # artifact = wandb.Artifact(name="heterogenous-GAT-static-graphs", type="model")
+            # checkpoint_file_name = f"../models/model-{epoch}.pt"
+            # torch.save(model.state_dict(), checkpoint_file_name)
+            # artifact.add_file(checkpoint_file_name)
+            # wandb.log_artifact(artifact)
+
+    print(f'Test F1: {test_info["f1-score"]:.4f}')  # metrics of the last epoch's test pass
+
+    save_model(model, "model.pt")
+    if use_wandb:
+        wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
+        wandb.finish()
+
-- 
GitLab