From c115151869a22cfeddae7c8933f8822518952b2e Mon Sep 17 00:00:00 2001
From: Liam Byrne <lhb1g20@soton.ac.uk>
Date: Mon, 1 May 2023 12:59:21 +0100
Subject: [PATCH] updated model configurations

---
 embeddings/NextTagEmbedding.py                |  10 +-
 embeddings/__pycache__/dataset.cpython-39.pyc | Bin 8383 -> 9121 bytes
 .../dataset_in_memory.cpython-39.pyc          | Bin 4543 -> 4616 bytes
 .../helper_functions.cpython-39.pyc           | Bin 3367 -> 3428 bytes
 .../hetero_GAT_constants.cpython-39.pyc       | Bin 1116 -> 1089 bytes
 .../static_graph_construction.cpython-39.pyc  | Bin 7007 -> 7012 bytes
 embeddings/gnn_sweep.py                       | 321 ++++++++++++++++++
 embeddings/helper_functions.py                |   3 +-
 embeddings/hetero_GAT.py                      | 110 +++++-
 embeddings/hetero_GAT_constants.py            |  15 +-
 10 files changed, 440 insertions(+), 19 deletions(-)
 create mode 100644 embeddings/gnn_sweep.py

diff --git a/embeddings/NextTagEmbedding.py b/embeddings/NextTagEmbedding.py
index 0d0ad72..6d5a857 100644
--- a/embeddings/NextTagEmbedding.py
+++ b/embeddings/NextTagEmbedding.py
@@ -162,16 +162,18 @@ class NextTagEmbedding(nn.Module):
 
 if __name__ == '__main__':
     tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db")
-    tet.from_db()
-    print(len(tet.post_tags))
-    print(len(tet.tag_vocab))
+    #tet.from_db()
+    #print(len(tet.post_tags))
+    #print(len(tet.tag_vocab))
 
 
     #tet = NextTagEmbeddingTrainer(context_length=3, emb_size=50)
 
-    #tet.from_files("../data/raw/all_tags.csv", "../data/raw/tag_vocab.csv")
+    tet.from_files("../all_tags.csv", "../tag_vocab.csv")
     # assert len(tet.post_tags) == 84187510, "Incorrect number of post tags!"
     # assert len(tet.tag_vocab) == 63653, "Incorrect vocab size!"
+    
+    print(len(tet.post_tags))
 
     tet.train(1000, 1)
     # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}")
diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc
index fd97a912b458d313f1ee3b398718af41bcc1a946..5a89c09867736d07f5956e8241d511ed95476100 100644
GIT binary patch
delta 3046
zcmdn*xX_(9k(ZZ?fq{X+bA?(;&_>>+Oni437#LC+q8L*cqL@<DCm&$ycjZmxlVnH{
znZuYOpQ6yh5+$6<mm)rgF-0*&sf8s<B$Y2kW)5SDa*9d|OOz;7PBlfXg(XU?oq>fR
zN<5fBQ)6=pvnf+78v_G_GXn!du>u1FLk&X~Ly=4g;{v7<#w_LqEHw;StSO8OnHDkz
zGo&*FGiWmT-Qp}L%1=%$E>2C+WGfP7U|_h#lwW*{JtwocBqg)x7E3{5NyaUfto+Qp
zBH78cY!Z_>*_0T?ChM?mVmvqbKAUrW33C?f0=624EcO%@kkKv-u`#ubB^(PlQ&?*l
z7c$kd)UXs)bucX8D&ekSS-_LRwve%zaUmlk16U0^SPg3pYf(@McMU5{zLu?qt%f<9
z1x=K>FsPj&jWLBIg|mgDhRKDYdGb{DRjfrC3=9mDy*Q$@d2g}g7o_IhVk^i@&dy1_
z#gdbsm{O$4z`#(X1tPRTgbs+1oqT{pm+{u*cN}JNw>a`rD@x)sQz~w8m*&Akq(}v%
z(0H;FrvanX<YG>@dI1Io27Xu&w=sf(cmcygh8U(=##*Kth6Ri@j0+iSnM;^zm=`cF
zWGG>&VQywjVeDmcVVJ-e8xq4*%Tmh<iiR526edZAbjB3s8rBpRFpIU9Ifl8Gt(Lu(
zrG#w(dktF(n<PUGJBUnUTF6+-QNvKbfTM;rg<X<iA!99P4JXJ77lznBU}L#T*s?fl
zxKcPI8B*A?S&HVRaON?ku%$DWuw`+jaP>0Pa)T}Af?1rxoz9p7v6QEl7iuXF$hsO{
z5DB)FuZCd(cMVqx57^Fn{u=%TJSn_291EEk8Nh758a_CiA7VY6F93BcOs!xI7ufMI
zzEBM}NNow*0^S<_622_{1p+k;S%N9THB7S^QbZuSXEUUTLiEjMuuBn}%f65aq^X7}
zh0~m&mZ`9*P%DKym_bwg7blkj6lgNuVlBxpNzA#$;*_6K84|@Bl39|II$4%SPWu*H
zKz?zFXUZ+s;N<+G)FML$1_qG&l#<GVRKHu?{^fb8MWMy1MV={>b9rQBi$SWQxC={D
zi%T-|^Wrm8qBs-tipx`r;xkhwxAVwKM{&o)1>&oi^$JQfg^EC#@Rs!Cg+kJjMckll
z#gSQ(T2z!@UVMu=CpE7~5|m<9cy%U6@|t-R!IR}JUU=$@PsuFO6eyAbr9mf<3Cx)(
z6;b?Ui8+}m@i6BU-(rlv#R20@zR9bhQsm9Rzz`)2aZ+)7Nj_ZhEjEZ?k>6w&J`F~n
z$+>*$7Dk}riG_!Si&>7Dg^`1iiIL+k533k6AF}|X5+e&E-+vxv5RZqEhfx3wc^FL?
zt0X4tvq?<;$7jwK1ahG6<XSdywjvO_baDj0RX`jA0|Pj#GNv%KaFj4CU@T$CVp_ml
z!;r<2!VD=Wni*@EN?2={ni-3{KuI31f}@5Blqo=}YME=8gBeOVCtu?av-gL@6<2Xd
zVo^zPd1gt5CQ}i}1GhNRGILV%5_3~mG8V;v90bm~svwpL0|P@Ph+#T8TR<0DR;&?F
zb>hNaE`XGR{81DPVp)KQXi)HTx>h8o7J$-#CJQ81VkYwmN-3v;RDf){#hjj6QUNM_
z!9_l!3L^)j6cY<mmB{1*0g=h^f@+LfliLI}m1>xR88q4ae!0REEHvFJ1iSjUI)^A|
zO;!+5oGdNms=bo&7ISe)(MqT{87m-V5~}YyKrS$y+%4qF_lrvtZm1^9WJY08A#iCY
z!N9<<k{ROctjUJLiqZiL3=FLx$MS#@03#cd&_6b2mcJ}4RXUR=2<dE|EzHBnTI3E&
zrt3u#8GAMhik@S%1@S|QK%Ug(C@KQkTMQyV#@}Mj$xpt;o}8aknwwXAiz}tHASW|9
zu_QI+7E5w|L1j_ZWC?Lc#;KE|#ETfUCZ83Tj?7|CXGmdQ#F)a;3o1WB!NOYQ2={<n
zYDsd20wm2AE2QNYDI_Yuk~)|h;9#hbT9%konh1&(xW1~;$=`*&CPzr{)VqLE3=7yB
zMQ$LMdVmNokSOCVw&K#H;?xpN7IdGmr4;2C<RxlC1FQ+8rvXHO!yQb3BD^S+fq`KG
zC=B^PVaUS9%=DLyh2;+ms{mt_HmD4lJVi)&a=$po<a-j5lPd(>9jZ*=u7vmw<{yav
zaCoRF8)SJ7C@wy-i%zx^QqnBS1qtPY2(Wv=1lUoa^u7rsZpy&GP~{4W%=&mxDCybc
z<R>TQ6x-?9fU5W8{M-V&qC!wmutc#J<QIb~z$jjrgW{9&Q!1nQ;M@{W2^u8;=a&|P
zOF&Ilq`-qGwW3;(14=*y$YDH?R10$6Ed@}SC<k`X8l<3uCqFL6DnpQ)S;0y+zmmvc
zWUSa6Eak??Rs}LIWb%CJ494w~Ib<UEm>C!tSYg$%)8w<_lD;L(pz<VzIh(0SEuA5l
zp@s=m-bgZl%99k<6t-3-NrnZiDeRy&kS2#;RTw-z(DGnNVtVmp0TIJ`NOhrLt55_=
zOht^KvVa973$8kf3_zjGWM`wvh7=H>`~WTtZ?R;8as_K~K~81~a(aqlFG)-<j!#L0
z1Pv&ZiooF`o(4+w@nG4s{GxbeScsnng%+p*!oY|~XgrK8j8z(wCkTl%mn5c7z9S($
zIYFj?yBZWyjUb|J@@*ME#;cRHWNo-5ZgIrNg9^#`_{ll4(g}GW1=1k@a1<2f7o--I
zR2Hd&#B@P~9*6){V?~J|7Pv^v1hJSvgbj!Ql_5n4yFl5Y2vnvNwSi>2Cf}7cbprW}
zgHeQwiIIz$iIIzygPB8&Lzjb@gI9>1kCTs^k4K1`PlS)FNC@nZ`1st!%)I#cDlR=e
z{glL##4@AFesUHPMW70$ipLRL^D5+}Wae5G#e*`$GPxW!bC7kGlO^QkC2z5mWag&c
zVk=0@OGzw-L<y)Mzctxk{(=>#j7kUFR5Sr(9aBni5!g=|ARcQ$QD$DrEw+mIl+5Ik
hTdc(yi3O?P>a@sba->4AunQL_A0ro|5FGL`0szKl@>Ku;

delta 2610
zcmZ4JzTc5Ik(ZZ?0R%YyCFgJCUCP9_hk=11l_82Tg&~S5MS1c8rhZkv6!AHXDe@@_
zEi6&OsSuuGic$+plt?=R3qzD>FoUMb<|)jkOuozv3=GZ;3=G903=9l43|S0C0ws(K
zm{J%QG6plGGXyheGWq=y)YH>XNi0d!FDS}SPAx7@P0?g75@lduxW$xTe2YCNv$!NB
zv#3aI@<n0c$#!f?j9ilw*)}mAnk>TZoRY$_kkN%9Hl~)bgmnR13TqAHLZ(`l8kVBG
z4u%EnB^)&@3pi8Q7BV(7E@Wh60IOjKt6{BSEpjX2s9}Z4*Rs{H)i7tXpouaUx=lXL
zzLWbFC)|0HTREZ_xh6m0kekfQsl^yI*@DxIF=}!?rwpUv<Q`4~M!v~=Io;|(e&&G%
zfDjV{LkYtIMv&JT8A_NIFfC-LWvpRXz`T$lhN+gRmKiF>Qo^u+C55SmA&WVSwT7jJ
zQGy|jNsIyPdW0^v8m5JewX7I=*bsW^p_<rhSQl_CWGLaRVQpqiVeVyeVVJ-e+f>U2
z)y!DKki}KQmc^aUn8H%Smc^6Kn8Mo29K&48UdvI-S;D)3uZBH^O_HI814O2QEarmh
zg*rBc9g9QjA^zj9VapO&$XLsbs-=b@iyz@uh*N90K@N3ch<y^nQp;J(Q_EY+SHhbm
zSi_gXmd!GOvB)KbBabPCEuFE1H%lmmvzMusAMALLlUY)@;0_n46|Ci~2fJJV<mMVd
z5D9j<5Y!o<pck&;UC3A~3>AU;rG_C(7{g)03q(@5Yq%CNF*4ME*gQ2ta5irZKZsqz
zyFj!?xI`>Ve1Sv_LzZMp244;HY=#uD|7SDIWm(7sQdYy9!eP!(%RGUxP_3{dg)5jr
zQ{WdTmjV=22?anhL1tdM0w`1H>FHI8xRvH5mt^MW*{T-nD1>C@rrN3&Yck$qEy*uQ
z%n6yik6X^CN--e6xI`f}Hz_qG1*AEtG&3h9wMapuur#$8q*=2_pMik^q%EbSvLMy(
z77t7&J~cP#7FS|kad~PHh&|bXM^;-?>K1cOe)=tz%)GSxTP!7+xv51CpiIe{lA2VS
zev2bMJ})shH9r29^yGXYX{#b`kR(@TNor9^X+chE@h#?@)Vv}|kSH(QrlkCo%3FMJ
zZb@cIPU_?}JUSfg1^LA#@tG-;uk)CD-C`+7EXlaVl9ivCcZ(OE@8VN3i!?=wWI#qT
zl^7P;gBZ-2DHXSvN{piT%Mx=kQ{rJ}7vEwo$}gWB#;d`4i!r{)WpV?rgjy6g)R!Q4
zMDam6rNya5@tG;NSc{YMi&Bf+CNJXEV04{)npfT52UHj`b1`u+^DuKTN-=XVbAf3o
zCJts6Mm9z!Mz+5^tYS<YjAAStOgu~=H6n}(j4X^?|9O}}Vq9P`9!3yUo2<?k#OBYy
zz)++;xt>pg6~vr4c@CeIOArGCLpwtnV+vymQws;Uq+kY>Y|V_dOeIV;OwEi%b|uUU
zSV|bOSQoI>Fl4dUFr~0aGAsmTN9G#lV1^Qo$!h#zF7A*><|-~pEGj83&n(H%WGXTM
zg%n3xW=<+7&Q>xO#e>|S0wPpFgfRmH!&DH%gn@yfN@Maw0g=g9_(epjgfPlx{mF0n
zRqHv6G(pO>Km^DWMIiSVnS;2|AR-1t#DZ+)bgf8EEdVD#7I4&o?Y_kcatt^H-C_Zy
z^CGZgGC|rvF1p2>o?0>mRC9pqBSr~E4n`>^7N#nZ(Bjl0_sP=)OeecbSTXxmrA#&y
zmrwvDe$Nyg1*iO!N>COl$w}2w@Gs9xEdr_cOqt9ttT;JVj(xJdpq+sx<1Oamk|J2R
zFjhb;2fF}dc@Zc$mVlfB4i1OOijq2$ZwiXkS8-{=bE+mwQ7A}B0Eh@=U|@(6fF-?n
zXoAvYy~R|Ja*MAhH8CZ=xG)ElHj65Yz+oc+3LEz1{G8I<ykbpeNXX=Y1mW4TI6fr}
z#bpW%3`M>S3=9iF;lc+>nT%{qLjTy9S^l!Ha4}ZtfE>EHONfV&;}#nvL?+J`PGnrU
z`LFOfM#~BY28NKLN>KXXC@KOuyBI`(UBaA`pIlS|@&#8)X+chAa$-qp$}N`U{DR7&
z>d72pj*Ocodx;e>Dox%iCe073v}&1Zn1UHJnf<DAKrtWz33;R-*HH)tWv9vig$yTG
zi}OS{f-Gi1vp9+^r6|83FA)|3jUX4*g9vbJfC+GvgfK8L>_zf7Gt*x-7M4FOtQ?d7
zi(7B@mx^X&tlYd>+KrK|8l*jV@^_gG#$%HcWFsb<$QN+efCL*rMAPK;@_vlBCjXPS
zQ4%WxW$;^E@$tF&DWy4}(u^lQzOXbg2b53Z<BJL=Cn`wTYJ>FYfCx|tSd<E4fm3}h
zh{XgVtU&}Qa*7Z(f!tiwH2Hvnsf_{y0|N)62p1D07c&zh7b^!dhZu(^2Qvq+5C<P8
zA2%N-pGc7qDEc|#<8u=;^Wx*HIG}lNvbmy#!YvM1>Z;;#1eeqbc`2E>R=3z7h1leF
z#T+&>kRvQ6|5TKhDDnas##WG+my%cv33N~vjhbwzbU~`f1EdgS9mM}kDaA!#ujEXY
gR+i#+0m-_82o-RfD_GfD*prKskCBT}2oCue0Xe0iuK)l5

diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
index b144eff4b8fc104d87e4e3c6da288641d1083a7f..b0f15c4568700226cde4638c04bf1bfcf832dee9 100644
GIT binary patch
delta 1172
zcmdn5+@Zpk$ji&cz`(%Zxk4>Pgl{6B4C9@N+H+iQqzFZcq>82pH#0_wrHZA9WHS|Q
zOBHHnU}VTsN)<~H&1NdPk;<1Unj+TB*vuFuo+^<lm@1ws(#+J%7$up)7|ft4zL|}2
zE{iHV0|SFI0|P^`0s{j>3S$aW3quLR0>%=CET#p_H4IrSDa;F*ChuXru2`hTz`&r%
zTBOLpz;KH_C$qRDC9|kViGhKkNEt*(O}1s1*nFDJfQeCO@*j>+M%~FloEP*=KypGL
zJsbr^`30#(C6z_0ATf0ip#vgxL4+QNu$>&h6;;m&@-h^&FfcH{-RQ%>z)-@_%uvf%
z!nlB`gfWYG0ZR=-3gbeig^Vr?u^P2ZC9DhBQkZI(7BbZ`*Dx1J)i5q#FX5<RUci~c
zypXY(aUmlkL!nd&R}Et`C=giuz`iQd2U%_a@*nptX3yf(60k3ev_ZlM9~yzU#vrdS
zC#IwnnSi*aATA3iG&EU?q(CAd*A`iWST>U*xsACb7<m|37{wTiJSI=$Ue9PS*^MW;
z-U6hJwJ0$!Jynyb2;{F@jCn=eAYI-df(I;LT#%Dl;tq-(P%tns3NaQ1GcYhDGeH9f
zM6oh3FmQkbLCT`Q0R-|~EfXk+YM4MCE@3WVsbQ*N0L6YSLkVjQLwz$SymVriYME<U
zO4va0Uc;QiEY2{SA%&%cy@sioaW+E=>s;0vmS*NK21bTLjTDAp22D1<Dh@q8{glL#
zMExq(wEUbD-6Bx16&W%xF#KZEEiBe6DAD97@&GxXttc@sB|rBTdvQi-T3SwO5y&TC
z58Rr(lvh9!6b72Cx4^~}BZZyU<deKUTpWyCOni(2j75=?<@sbKrh$D{!r07M%TNOH
z3u6ic$X}C#_>>btk<L=W(!l_7_FRacL1L^WETFKM%>Wf+FJVk!gQ-hlpUd3A(9B%R
zR00y`fT;zE*Dy9S*D}{I)i9?q1v6-JPQJ#M%@{e^i9cJnimjw5F*8rMD1?E5p^CL6
zwYWqV=3|y1fBz7OH^GU+V)9LXcgCp6iUN9COdQNIj6zHzj3P`Dj1r7R37{b2h>y=r
z%*>0Aui}CRYnhS$<XnMRF{xYZIr-`7nR)4uq!<cH6-JXE2za~4fS90YhonGISc5~g
zC=w(AGPlSM#EJ&lz?zv?P+D?}Ejc4UGdZ;=2qXcHhlt6!f@*qNpwz?20fszG9E?0n
ZAShrU#L2<T!O6kT!N>xZ1F7I)1OTSz`+NWZ

delta 1098
zcmeBB*{{r($ji&cz`(%J7Q>x#iFYEO4CA$l+H>qLrHZ5oH8Vzuri!KrXEPOTN)>8m
zU}VTsNEJ;H$!03Ll**SXk|Ns7*vuFumMWeqm?{QlOQbLcGiZu!=3$)6GFgt5Q$U1)
zfgy!4g{g(1gkb?=3iIT1tk+e43F_(TrzDmn>K7E{C#Mz{r>1B!6)7?>Fx+C#$t*5O
z$t)^Tne50Wy!jrR0TZL{WPZ+2M!m_YoEP*=85kIfgg`1e3X1XzQj1C|i&Q~k>L5ZF
zMCgGCeGp+cIhiY}9^^e3W@cbuU}s=paAsg&D0X6CU?^c|W~gN-VO+qJ!nlypg&|g>
zma&9+0ZR%~4dX(lTBaJNB7qu)1*|1(HB1ZGQ<xVrHZv|{WMn85DB-AKXa@O~#ji*k
zWSaqqFl1n0(B!<uT%1}0_E3=ySlAdOX999Ib7D$L5!e@IATA5YznaWNAV(G1fK=In
zJU)2_w*nU*BM&1BqZnh6=j21&>lqCv7x5&^Sb}u17A5ATr)n}4`GR<ic}3hHoj#L4
z^2pYQFfcGAGr~N~z`(%5z`(!(5(N1<0_JB%ke_S7ekoxtVX0xPVQ6M*W~^l>VFmfi
zg&|fahN+gRmbrut6n`~LDa_&wvl&uYO4w@{n;B;_q_ECqtzm9v4r5?sC{#;f2xib^
z^ZUh-0gKuy*0lVb6x||Fh!hzyFfjaL(=9C4D=5+ADDnh3o2@7@FC{<s7JG3<X<AxN
zYLN%X6Cmf`;!Mpe%}p&zEJ-b51qp&Yt;uo=Y+^Aath^`x;q|d)W8`AuV-#R4iiUa+
z?r)H<L3%+kUOWZnYlae#KN(XPN|;KRC%5w{C$cPHO<^fv?qC4<X)Y5eZc9L7tR>74
ze?rCBQrJ@1Vd_#i=7MxE*D{rW#5rMVLE<%x&CIpTHB2?kX-vTknp~5)`Lh`#CYSSP
z>lB4DFfdfHl@ujr=IIuNfjF!ssl_F_n(RfOAc2GcIB{4`<`-~hjF}uRpeM)3#w^1q
z#3aHf!X&{c!B~_C3O$ba_}s+Iy!iOZ>jh%PByX|j<fo@+=A}bYT{tMzj3+Azdb`Jh
zn4pk=q&l!Ypg;x(8Q82Mdys*wnRx}JCAZj;Gx9T&Q;UK@+Q2aq1!8f7Q{@>!K1Pkn
Pw*}>R`8XJ1kcSZfFU<3e

diff --git a/embeddings/__pycache__/helper_functions.cpython-39.pyc b/embeddings/__pycache__/helper_functions.cpython-39.pyc
index c703478f4863918e0f7707354d90e4d738152c74..0235bc1e59c4ce3b5175b5dfeb5eb5f57eba5468 100644
GIT binary patch
delta 973
zcmZ23^+bv<k(ZZ?fq{V`c8hLGB+o`ZDMrRylQkK?G4fAVV5(*m+&qWLkC9Pq@@r;G
zM#0IdEGcY43=9lKQj_~xB<;9COg;t%hFk1;@p<_vsl`$3dGV<!>8Zs<$_xw)Q35H6
zC5gqUCGm+xC7Ef7$t6YnAT3}*VDd8-Q%3E{GOY2MA`A=+n#@JAAW1opd2Gq~d1;yH
zMLHl6uyQaVFnI~9HlyC;3#`3T1|W6Z$vKI|#qoKGxv9mPED!_aCzrEna)Z@_34zJ0
z*%TNpCZAzb;o@arU~pz&U?>im{EN+As)lhkLkj0yrW%G6E=h(=)`g6W3@Hr344T}N
zUD;LYIrY;Fb#oJwi}I`Z(hPNrlk<yGp?n^YRC#J<dPYfVN|hj77B0l5Uy_)VlUk$+
z@){G=Wngv`cTs9_X--LTd~RYvktLcBLB0~n$xn|jN}ar$U04JZybO#ijC?ToyU1qp
z1NO-}wpcWB6(^RZ#^>gzq~`d6e8~igTE-%`$@4kNl<Yx9vzC@*<`fr6g4jGoiFqmc
zx$(sXIhiGzY(-L&%{g6B0<N7?RyYV`5C_;G0Y)As3C1Gd$=f*(7=Xf7lMx)xE17OF
z=^5N&tU!)lkPTcmIhn;J$@#ejb}^Heb7k>EG~HrN&M(a?De{}Fz%3#I5-Sn_`6dvQ
z5J8Nf$sXK3^`W5PV!Xv%T$)=1@{T6sEjF;mVu(F(J#k<?Ak{2I@*sOzGV{{%i@+Xc
zD@n~O&M(qrDv|=p@In#{B$^=RXfhRrgA51(5g;>fi6<xKB$rO!#w8RU3GzJ;$oEWK
zjC_njj6#fDOhwTkyTosC<fWFB=NDyH6mf#|g6$J6$uCOIh)++=&rK~U%1qXS1UM)r
zAc?>U<c!HRJjPBSrxpo<xXi^RMMc^mGxR`&1&FW#5pE#D7es)pgm@bqUSL9C@<$#+
KRz5x!4t4<ch~R1f

delta 925
zcmaDNwOooXk(ZZ?fq{XcLXkh^H}^(9DMrSq$(oGc82KhEFjX@OY@Wm9$H*u;`8Bg8
zqrhZUmJ~KY1_p*A$;tgJl6D}SMZ63Q3{mWP@p<_vsl^~hYD#))agh>8SRf^_B(XTP
zBtEgIBr`2Bxul2>q#I1|PkzQ?%BVS6hBaPOn1O*oletI+Bq__lz;KH#IX^EgGrdR)
zBm!0rCio{WVbx~TnS6n@S4t0L0e5mvVsUYNUSe))u_g<|0J+KKY?|C)^<aX3@@h5(
zM$^e>*i^W<85kIx85kIf-6sEHv**cVS;)x9kirnmpvgHom|ex4SwGFNiX+WXw>UY!
zD7A`Bza%j!C$&fe<US^-@nCiocTs9_X--LTd~RYvkr}!(U$e^z@qogWk%f^D0{;|Q
zPL}1EtYd|y7UW&7;>5Dl_}u)I)Er-s>zKe=i<~Fl;V4tG0jXduEy>I&E|LJTd5RMA
zQu1@-iwklxOElSvBqx_}x>|$%3MTkLiX_0Y@g=FnCGjOiiJ5uv1(hZFMada~APYFa
z76>r%Fi9{Lc}<?kEi##dYmEU&yCx$z?p8A0V$w6X#aMwHWgvsOY;rP-OOo?*3+$pN
zKjF$UhiJORnw(#nS5o8+3LeH=JP<o`5|dJMiXldToK@t<z`zg#3IYGgt=v8VK@io<
z#ihA0oorxp(DlTE^?+2f6v=^{!IGJmmR|%88@7_vyyE;KO{OBr$%;I};$R0c6@`Gb
zfP7m7viX+y<ab=+;SnI;gJXt?i;<5}h*5}<i>U}CswsYpBQLe2JijQrq6ic|V8f$C
z!GRW^o|>PVT2hpmtOtn*Pz*uRgFPr6PTt95>;!UGkpReX%*7=|MVcV{bwGqEh_D0^
eptM)y1!94$goF#oxeylrWIbL(R$e|94t4-`Y0h{6

diff --git a/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc b/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc
index dc4dbeba04cf288cfd9bf78b062a95b945be411b..4fb58f303d20a0568ce0f968a6241a656623153b 100644
GIT binary patch
delta 213
zcmcb^agc*Ik(ZZ?fq{YHV5n(|+eBVj#)^sBHjXS&d?{Q}{3+Z~0x3LEf+;LfLMcp9
z!s!etyo(s4L{j*oL{s>q#8L#J#8U*LBvOQ;BvTlpq*8>Vq*GXe88k&EE>U3gn|Q8X
zOoV}fA^p$FkC%Lt0_@Y%z~G(z<YvYa9VU0@m&^<d3@;rR7#OOAf*svm<BfCT%?wiG
z%}g@%3QE$A!QisJrp#n@rf1w-L9RaW!J$sUt|60onAfqXFfcF_sZ2h=oWf`}S(Qa?
IvL}l$01HPs6aWAK

delta 256
zcmX@eafgF9k(ZZ?fq{Wx(-O55(}}#Yj3pDbZEU%s_)@r|_)~bI1X6gT1XGxzgwh#O
z_!con38(N!iKGZbiKYlfiKPfdiKhrhNu)4ENv4QINu`JeGiZuUT%o{K#o-zeV5nPE
znm6%ey&1@c^gk;<Uh+)}uuo3|gLn2-a(a6Dx%nxnImP<vi6!xciSb4Tsqtne8S%zB
zdIcqu;}}Z}nB5&iUNSQ<FuZhNU|^^cfhsgaR+tX*00{eON>BdB_-yhArnPL!3=9lK
b%99r}r!bmK=3r41WM*OHU}RxtVT3^d?~Fq*

diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc
index 2d145e62ef1474e70b94717c8d6774b17b1e19f1..1ef4ec5d29e091fa7b326d0c1910abf0eddefc42 100644
GIT binary patch
delta 1419
zcmca__QZ@gk(ZZ?fq{X+bA?*UoQ=Fond{jY7#JKF7#NE8FfcHrGDI<^Fhnt>GH0=*
zFs3l2FsHDju=cV<v8J+Rv81r2u%~dOaP~4sv8Qr0GemKwaHVj!Fhp^sa%b_R@@Da+
z@XTRK;Z5OdVTt00i}BB4N)bpAY+;EKfQt#uVM-BB5ouxZixNx`O%ZEhh!RQ>PmyS0
zh!Rc}Nfk{Mkz`1boWqnNl_K535+#-@nkp{IkRmgOF-0~-u7xE^B2_X)ehy=bLW*Jw
zOO#ZKQZR$2@-6Yn2}0tNSy`9XCo_V)3B@2b69WT-GsvHQ3=9mJ3^fe#5R$QksTss$
zPGPKJh-XO$v01?+8<=DVlN?}@6HIcYFx4=`bEh!ZFvRnuFa$GbviQB^U|?Vfc`3rc
zz@W){i#vYu1=i<Uj71=Rkt72H1B8%bU|_h#5g(tKmst`YuMBdH1jvajn|0VD8QCmA
z3N0shb4X3z%O){-6-Pd!*=8wDK}JUN$vRwrVe<Oi5>{3qr8Xej_={{oYy}Wu2O^9?
zgguA=MQf2Gh@}G}^gx8u<VJ2`Gf+Sjxj?jo-53m#a|ID@Ai^C)cz|5Xm|5fn;&?MK
zFibwcEkAiVKmX((-1&^=lQVdJGg?ee;ayrE0#X4A<RWVj3+ygFZ0@?nT9lZVo?3*>
zdHNu&P9Q=PL>PexEszuWAi)>`62RibqA*a{gIx)C;6A?fZjm6ptPod#O#}teE#Z>H
z^!U`=q|}s@%)IpY-29Z(oLhqFsU`6!5}Hg9OZ_J==3k_XY!t{zMPQ@Ea`RJ4b5hYw
zn;b77#~3iVP9T93l+=pMK_NF=NNO^p;6k-1aF|v=k^?wlfYL-!3<CqhM3Dbt85kI<
ztR^Q2sZD+-D5?hv$(R5C|NmdbWM`wve2cZXASbh=2<#nDv}>|JoC&tbXR@J?7^CB4
z521Hj0U*ya6@jcP3IiDowhK(aEt~-okDZ(#>??$9Ru)Luck+H=F-C{U*M!{}%{R-5
zd}E6A2Z^hJ2$10~LHS9O@fK@JNo7H*CUX%;`W9zOYH@N=W>RW#Q7*^@5aFUCkghy*
zfqaNSz~o$U3C4iQZQ^?JK_G?UAR-<_B!P$&5CL*vk@Mus;_JB0Kzwj|jhNgcajssv
zNE##}10rNW1Uo2<xpMLoQy|e>lmHS1m2*YOAQq?qECOd0Q;<rgf|OgFE{P?HZbgZ?
zsYRfyaf>Y^HLp0os0gQf<Uj_2f}n^SoWVFje22-qB(?cKIhun}h*5x%jgdoi@>|KH
UY%L&J=gB*y<QeTI-;int03ie~Z2$lO

delta 1486
zcmaE2cHfLQk(ZZ?fq{YHPXc$!nvJ|mnd_Mu7#JKF7#NDDFfcHrGDI<^Fhnt>Fs3lI
zFhnt@vShKQvSqQSFwbF1VM$?aVTs~^i?PjNN?}joXkm%sgo|;`VM^gj;cj7x;!5F3
z;ca1v;!fd9;csDx;z{LA<xAz2WJnR1!;~VJBGkeX#h=Qjk}4p{kRm*XF-0Urw1p)~
zFjXi;Yz||Jc#1>|OO$YmWH5uK)Z|(g$@*kQki(&vg@J*Aje&u|8RYUZ1_p)_hGvFZ
z#uCOXu-j9ZL9Xp(sbwl*&SFYo1GzJWvzNJ+Ig_D=A)W<HvX-zlgLv#IEHw=A9O)o7
zCz#{{liXmE2Tby&OMs;LQn+du;`vj!YZ&4MQW%06G<p0cPh(lC9P$zrteS#FAgYL+
zfq~%`S5AInN_=j9N@`9K$f=XJv%X?xDH56N$RRCyOSmL4Jw7!zDK#Y}GcO&YY_g1i
zgjSIx0|NtukOCRQ5g(tKmst`YuL$z09LWDnY>YyT0!*7vvqdtpS%Z|@OxEI%Vl6Ud
zU|^UW!;#Nux%mu-AS0vI<a?ZdVJaSSNz~hd6x%Z}Flh1@fg-&~0VDv@R%8rfIf4i$
z5CIC<A{`J*4@9_t2u%=S2qLsVge!<}gBT_PVuykVcM#zLB0NEa7swBcnMFPz&g3R;
zNwy#mD|qsB?j}a7$x=MO8LcOa@h+_o0||qoxX2E~0y~lqn<H<r7A5ATrxsyzt3F68
z!mUOiF37E#d_`avMu7ydxUeWf86*RCq#uao4<Z6UL?DQOx;d8bkW@6t!>ka8fdd2_
zgl3?S_vDvm44j<GzsM3<5y*Z>IGZytFhq%@r<Q=CyEG>i5&xP@5cdU54i-q@1lwx`
zGO<8NYVtLKg=#V2P^^F?3vdDfrH!IE1_p+SAos;HFfdeEO|EAVpX?*VKl!krD5L%4
zn}X7sph(taffx=p7Ua8Ij3wwMIZu`qdZ!QarY2Jn$hM*gkhx&1zy#dB86a`6eGXvz
z76|Z9P7qc?u_+6r1{8$17<18Wa+<ta*qza8^Jn32O!eTvRRgI8S^bg;6t;}FSW`+W
z3sN<ii$K!1I8#!KlZ!HwQj3dnLGlpcq9Ty4JamD4hycj6Tilts1^GoKsVVW9c`&UY
ziK1ed*2x*-it-^KUL=S}1QE#~A_YW%!mh}5@(S^F+?F6dIC)1+PLw#8AW|d^5|IHB
zvLM0&6hTY{DYrOX5=#=@iV|~Ei$M9~7F$SaUU7a=5l&k`Il2hs@gi`R;{=)JG<m0_
iHir-c0|N&Whxp`Al1JH^LE^5H_e#k#nohne)eZmwk}hrl

diff --git a/embeddings/gnn_sweep.py b/embeddings/gnn_sweep.py
new file mode 100644
index 0000000..4d5d830
--- /dev/null
+++ b/embeddings/gnn_sweep.py
@@ -0,0 +1,321 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import pandas as pd
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv
+from helper_functions import calculate_class_weights, split_test_train_pytorch
+import wandb
+from torch_geometric.utils import to_networkx
+import torch.nn.functional as F
+from sklearn.model_selection import KFold
+from torch.optim.lr_scheduler import ExponentialLR
+import pickle
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+import helper_functions
+from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+
+if OS_NAME == "linux":
+    torch.multiprocessing.set_sharing_strategy('file_system')
+    import resource
+    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+"""
+G
+A
+T
+"""
+
+class HeteroGAT(torch.nn.Module):
+    """
+    Heterogeneous Graph Attentional Network (GAT) model.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GAT")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create Graph Attentional layers
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag"]}
+
+        
+        for conv in self.convs:
+            break
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        outs = []
+        
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        out = torch.cat([out, post_emb], dim=1).to(device)
+        
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+        
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+        
+        out = self.softmax(out)
+        return out
+
+
+"""
+T
+R
+A
+I
+N
+"""        
+def train_epoch(train_loader):
+    running_loss = 0.0
+    model.train()
+
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+
+        optimizer.zero_grad()  # Clear gradients.
+
+        if INCLUDE_ANSWER:
+            # Concatenate question and answer embeddings to form post embeddings
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            # Use only question embeddings as post embedding
+            post_emb = data.question_emb.to(device)
+        post_emb.requires_grad = True
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device)
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if i % 5 == 0:
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+"""
+T
+E
+S
+T
+"""
+def test(loader):
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if USE_WANDB else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    cumulative_loss = 0
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device)
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        cumulative_loss += loss.item()
+        
+        # Use the class with highest probability.
+        pred = out.argmax(dim=1)  
+        
+        # Cache the predictions for calculating metrics
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        
+        # Log table of predictions to WandB
+        if USE_WANDB:
+            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            
+            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
+                table.add_data(label, pred)
+    
+    # Collate results into a single dictionary
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score-weighted": f1_score(true_labels, predictions, average='weighted'),
+        "f1-score-macro": f1_score(true_labels, predictions, average='macro'),
+        "loss": cumulative_loss / len(loader),
+        "table": table,
+        "preds": predictions, 
+        "trues": true_labels 
+    }
+    return test_results
+
+
+
+
+"""
+SWEEP
+"""
+
+def build_dataset(train_batch_size):
+    train_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TRAIN_DATA_PATH, question_ids=[])
+    test_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TEST_DATA_PATH, question_ids=[])
+
+    class_weights = calculate_class_weights(train_dataset).to(device)
+    train_labels = [x.label for x in train_dataset]
+    sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+    
+    
+    # Dataloaders
+    log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=train_batch_size, num_workers=NUM_WORKERS)
+    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, num_workers=NUM_WORKERS)
+    return train_loader, test_loader
+
+    
+def build_network(channels, layers):
+    model = HeteroGAT(hidden_channels=channels, out_channels=2, num_layers=layers)
+    return model.to(device)
+    
+
+    
+
+def train(config=None):
+    # Initialize a new wandb run
+    with wandb.init(config=config):
+        # If called by wandb.agent, as below,
+        # this config will be set by Sweep Controller
+        config = wandb.config
+
+        train_loader, test_loader = build_dataset(config.batch_size)
+        
+        DROPOUT = config.dropout
+        global model
+        model = build_network(config.hidden_channels, config.num_layers)
+        
+        
+        # Optimizers & Loss function
+        global optimizer
+        optimizer = torch.optim.Adam(model.parameters(), lr=config.initial_lr)
+        global scheduler
+        scheduler = ExponentialLR(optimizer, gamma=GAMMA, verbose=True)
+    
+        # Cross Entropy Loss (with optional class weights)
+        global criterion
+        criterion = torch.nn.CrossEntropyLoss()
+        
+        for epoch in range(config.epochs):
+            train_epoch(train_loader)
+            f1 = test(test_loader)
+            wandb.log({'validation/weighted-f1': f1, "epoch": epoch})    
+
+def test(loader):
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+        
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        
+        # Use the class with highest probability.
+        pred = out.argmax(dim=1)  
+        
+        # Cache the predictions for calculating metrics
+        predictions += list([x.item() for x in pred])
+        true_labels += list([x.item() for x in data.label])
+        
+    
+    return f1_score(true_labels, predictions, average='weighted')
+
+"""
+M
+A
+I
+N
+"""
+if __name__ == '__main__':
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log.info(f"Proceeding with {device} . .")
+
+    wandb.login()
+    
+    sweep_configuration = sweep_configuration = {
+        'method': 'bayes',
+        'name': 'sweep',
+        'metric': {
+            'goal': 'maximize', 
+            'name': 'validation/weighted-f1'
+        },
+        'parameters': {
+            'batch_size': {'values': [32, 64, 128, 256]},
+            'epochs': {'max': 100, 'min': 5},
+            'initial_lr': {'max': 0.015, 'min': 0.0001},
+            'num_layers': {'values': [1,2,3]},
+            'hidden_channels': {'values': [32, 64, 128, 256]},
+            'dropout': {'max': 0.9, 'min': 0.2}
+        }
+    }
+    
+    sweep_id = wandb.sweep(sweep=sweep_configuration, project=WANDB_PROJECT_NAME)
+    wandb.agent(sweep_id, function=train, count=100)
+        
+
+
diff --git a/embeddings/helper_functions.py b/embeddings/helper_functions.py
index d5ec672..9699079 100644
--- a/embeddings/helper_functions.py
+++ b/embeddings/helper_functions.py
@@ -65,7 +65,8 @@ def log_results_to_wandb(results_map, results_name: str):
     wandb.log({
                 f"{results_name}/loss": results_map["loss"],
                 f"{results_name}/accuracy": results_map["accuracy"],
-                f"{results_name}/f1": results_map["f1-score"],
+                f"{results_name}/f1-macro": results_map["f1-score-macro"],
+                f"{results_name}/f1-weighted": results_map["f1-score-weighted"],
                 f"{results_name}/table": results_map["table"]
             })
             
diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py
index f9e4f17..2275741 100644
--- a/embeddings/hetero_GAT.py
+++ b/embeddings/hetero_GAT.py
@@ -10,7 +10,7 @@ import plotly
 import torch
 from sklearn.metrics import f1_score, accuracy_score
 from torch_geometric.loader import DataLoader
-from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv
+from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv, GraphConv
 from helper_functions import calculate_class_weights, split_test_train_pytorch
 import wandb
 from torch_geometric.utils import to_networkx
@@ -24,7 +24,7 @@ from dataset import UserGraphDataset
 from dataset_in_memory import UserGraphDatasetInMemory
 from Visualize import GraphVisualization
 import helper_functions
-from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL
+from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
 
@@ -47,6 +47,7 @@ class HeteroGAT(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
+        log.info("MODEL: GAT")
 
         self.convs = torch.nn.ModuleList()
 
@@ -71,12 +72,17 @@ class HeteroGAT(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=-1)
 
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        
         for conv in self.convs:
+            break
             x_dict = conv(x_dict, edge_index_dict)
             x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
             x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
 
         outs = []
+        
         for x, batch in zip(x_dict.values(), batch_dict.values()):
             if len(x):
                 outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
@@ -120,7 +126,7 @@ class HeteroGraphSAGE(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
-
+        log.info("MODEL: GraphSAGE")
         self.convs = torch.nn.ModuleList()
 
         # Create Graph Attentional layers
@@ -173,6 +179,87 @@ class HeteroGraphSAGE(torch.nn.Module):
         out = self.softmax(out)
         return out
         
+"""
+G
+R
+A
+P
+H
+C
+O
+N
+V
+"""
+"""
+G
+C
+"""
+
+
+class HeteroGraphConv(torch.nn.Module):
+    """
+    Heterogeneous GraphConv model: a stack of HeteroConv layers (one GraphConv
+    per edge type), per-node-type mean pooling, and a two-layer classifier.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GraphConv")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create GraphConv layers. NOTE: GraphConv takes no 'heads' or
+        # 'add_self_loops' kwargs (GATConv-only); passing them raises TypeError.
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'comment'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('comment', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # Keep only the node types the conv layers are defined over.
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        # BUGFIX: the loop previously began with a bare 'break', so message
+        # passing never ran and the model ignored the graph structure entirely.
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        # Pool each node type down to one vector per graph in the batch.
+        outs = []
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                # No nodes of this type in the batch: contribute zeros.
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        # Append the post (text) embedding before the classifier head.
+        out = torch.cat([out, post_emb], dim=1).to(device)
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+        out = self.softmax(out)
+        return out
         
 """
 T
@@ -291,7 +378,7 @@ if __name__ == '__main__':
         config.initial_lr = START_LR
         config.gamma = GAMMA
         config.batch_size = TRAIN_BATCH_SIZE
-    
+        
 
     # Datasets
     if IN_MEMORY_DATASET:
@@ -300,7 +387,7 @@ if __name__ == '__main__':
     else:
         dataset = UserGraphDataset(root=ROOT, skip_processing=True)
         train_dataset, test_dataset = split_test_train_pytorch(dataset, 0.7)
-
+        
     if CROSS_VALIDATE:
         print(FOLD_FILES)
         folds = [UserGraphDatasetInMemory(root="../data", file_name_out=fold_path, question_ids=[]) for fold_path in FOLD_FILES]
@@ -385,6 +472,13 @@ if __name__ == '__main__':
     if USE_WANDB:
         wandb.log(data_details)
 
+    # EXP3: optionally train on only the first REL_SUBSET fraction of the training set
+    if REL_SUBSET is not None:
+        indices = list(range(int(len(train_dataset)*REL_SUBSET)))
+        train_dataset = torch.utils.data.Subset(train_dataset, indices)
+        log.info(f"Subset contains {len(train_dataset)}")
+
+
     sampler = None
     class_weights = calculate_class_weights(train_dataset).to(device)
     
@@ -392,7 +486,8 @@ if __name__ == '__main__':
     if USE_CLASS_WEIGHTS_SAMPLER:
         train_labels = [x.label for x in train_dataset]
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
-
+    
+    
     # Dataloaders
     log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
     train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -403,11 +498,12 @@ if __name__ == '__main__':
         model = HeteroGAT(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     elif MODEL == "SAGE":
         model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
+    elif MODEL == "GC":
+        model = HeteroGraphConv(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     else:
         log.error(f"Model does not exist! ({MODEL})")
         exit(1)
         
-    model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     model.to(device)  # To GPU if available
     
     if WARM_START_FILE is not None:
diff --git a/embeddings/hetero_GAT_constants.py b/embeddings/hetero_GAT_constants.py
index 4b296bb..6c4dded 100644
--- a/embeddings/hetero_GAT_constants.py
+++ b/embeddings/hetero_GAT_constants.py
@@ -11,7 +11,7 @@ USE_CLASS_WEIGHTS_LOSS = False
 # W&B dashboard logging
 USE_WANDB = False
 WANDB_PROJECT_NAME = "heterogeneous-GAT-model"
-WANDB_RUN_NAME = "EXP1-run"  # None for timestamp
+WANDB_RUN_NAME = None  # None for timestamp
 
 # OS
 OS_NAME = "linux"  # "windows" or "linux"
@@ -21,10 +21,11 @@ NUM_WORKERS = 14
 ROOT = "../../../data/lhb1g20"
 TRAIN_DATA_PATH = "../../../../../data/lhb1g20/train-4175-qs.pt"
 TEST_DATA_PATH = "../../../../../data/lhb1g20/test-1790-qs.pt"
-EPOCHS = 10
+REL_SUBSET = None
+EPOCHS = 20  
 START_LR = 0.001
 GAMMA = 0.95
-WARM_START_FILE = "../models/gat_qa_20e_64h_3l.pt"
+WARM_START_FILE = None #"../models/gat_qa_20e_64h_3l.pt"
 
 # (Optional) k-fold cross validation
 CROSS_VALIDATE = False
@@ -32,9 +33,9 @@ FOLD_FILES = ['fold-1-6001-qs.pt', 'fold-2-6001-qs.pt', 'fold-3-6001-qs.pt', 'fo
 PICKLE_PATH_KF = 'q_kf_results.pkl'
 
 # Model architecture 
-MODEL = "GAT"
+MODEL = "GC"
 NUM_LAYERS = 3
 HIDDEN_CHANNELS = 64
-FINAL_MODEL_OUT_PATH = "gat_qa_10e_64h_3l.pt"
-SAVE_CHECKPOINTS = False
-DROPOUT=0.0
+FINAL_MODEL_OUT_PATH = "gc_qa_20e_64h_3l.pt"
+SAVE_CHECKPOINTS = True
+DROPOUT=0.3
-- 
GitLab