From c115151869a22cfeddae7c8933f8822518952b2e Mon Sep 17 00:00:00 2001
From: Liam Byrne <lhb1g20@soton.ac.uk>
Date: Mon, 1 May 2023 12:59:21 +0100
Subject: [PATCH] updated model configurations

---
 embeddings/NextTagEmbedding.py                |  10 +-
 embeddings/__pycache__/dataset.cpython-39.pyc | Bin 8383 -> 9121 bytes
 .../dataset_in_memory.cpython-39.pyc          | Bin 4543 -> 4616 bytes
 .../helper_functions.cpython-39.pyc           | Bin 3367 -> 3428 bytes
 .../hetero_GAT_constants.cpython-39.pyc       | Bin 1116 -> 1089 bytes
 .../static_graph_construction.cpython-39.pyc  | Bin 7007 -> 7012 bytes
 embeddings/gnn_sweep.py                       | 321 ++++++++++++++++++
 embeddings/helper_functions.py                |   3 +-
 embeddings/hetero_GAT.py                      | 110 +++++-
 embeddings/hetero_GAT_constants.py            |  15 +-
 10 files changed, 440 insertions(+), 19 deletions(-)
 create mode 100644 embeddings/gnn_sweep.py

diff --git a/embeddings/NextTagEmbedding.py b/embeddings/NextTagEmbedding.py
index 0d0ad72..6d5a857 100644
--- a/embeddings/NextTagEmbedding.py
+++ b/embeddings/NextTagEmbedding.py
@@ -162,16 +162,18 @@ class NextTagEmbedding(nn.Module):
 
 if __name__ == '__main__':
     tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db")
-    tet.from_db()
-    print(len(tet.post_tags))
-    print(len(tet.tag_vocab))
+    #tet.from_db()
+    #print(len(tet.post_tags))
+    #print(len(tet.tag_vocab))
 
     #tet = NextTagEmbeddingTrainer(context_length=3, emb_size=50)
-    #tet.from_files("../data/raw/all_tags.csv", "../data/raw/tag_vocab.csv")
+    tet.from_files("../all_tags.csv", "../tag_vocab.csv")
 
     # assert len(tet.post_tags) == 84187510, "Incorrect number of post tags!"
     # assert len(tet.tag_vocab) == 63653, "Incorrect vocab size!"
+
+    print(len(tet.post_tags))
 
     tet.train(1000, 1)
 
     # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}")
diff --git a/embeddings/__pycache__/dataset.cpython-39.pyc b/embeddings/__pycache__/dataset.cpython-39.pyc
index fd97a912b458d313f1ee3b398718af41bcc1a946..5a89c09867736d07f5956e8241d511ed95476100 100644
Binary files a/embeddings/__pycache__/dataset.cpython-39.pyc and b/embeddings/__pycache__/dataset.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc
index b144eff4b8fc104d87e4e3c6da288641d1083a7f..b0f15c4568700226cde4638c04bf1bfcf832dee9 100644
Binary files a/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc and b/embeddings/__pycache__/dataset_in_memory.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/helper_functions.cpython-39.pyc b/embeddings/__pycache__/helper_functions.cpython-39.pyc
index c703478f4863918e0f7707354d90e4d738152c74..0235bc1e59c4ce3b5175b5dfeb5eb5f57eba5468 100644
Binary files a/embeddings/__pycache__/helper_functions.cpython-39.pyc and b/embeddings/__pycache__/helper_functions.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc b/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc
index dc4dbeba04cf288cfd9bf78b062a95b945be411b..4fb58f303d20a0568ce0f968a6241a656623153b 100644
Binary files a/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc and b/embeddings/__pycache__/hetero_GAT_constants.cpython-39.pyc differ
diff --git a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc
index 2d145e62ef1474e70b94717c8d6774b17b1e19f1..1ef4ec5d29e091fa7b326d0c1910abf0eddefc42 100644
Binary files a/embeddings/__pycache__/static_graph_construction.cpython-39.pyc and b/embeddings/__pycache__/static_graph_construction.cpython-39.pyc differ
diff --git a/embeddings/gnn_sweep.py b/embeddings/gnn_sweep.py
new file mode 100644
index 0000000..4d5d830
--- /dev/null
+++ b/embeddings/gnn_sweep.py
@@ -0,0 +1,321 @@
+import json
+import logging
+import os
+import string
+import time
+
+import networkx as nx
+import pandas as pd
+import plotly
+import torch
+from sklearn.metrics import f1_score, accuracy_score
+from torch_geometric.loader import DataLoader
+from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv
+from helper_functions import calculate_class_weights, split_test_train_pytorch
+import wandb
+from torch_geometric.utils import to_networkx
+import torch.nn.functional as F
+from sklearn.model_selection import KFold
+from torch.optim.lr_scheduler import ExponentialLR
+import pickle
+
+from custom_logger import setup_custom_logger
+from dataset import UserGraphDataset
+from dataset_in_memory import UserGraphDatasetInMemory
+from Visualize import GraphVisualization
+import helper_functions
+from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL, REL_SUBSET
+
+log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
+
+if OS_NAME == "linux":
+    torch.multiprocessing.set_sharing_strategy('file_system')
+    import resource
+    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
+
+
+"""
+G
+A
+T
+"""
+
+class HeteroGAT(torch.nn.Module):
+    """
+    Heterogeneous Graph Attentional Network (GAT) model.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GAT")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create Graph Attentional layers
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('tag', 'describes', 'comment'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'question'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('module', 'imported_in', 'answer'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('comment', 'rev_describes', 'tag'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('question', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+                ('answer', 'rev_imported_in', 'module'): GATConv((-1, -1), hidden_channels, add_self_loops=False, heads=6),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        # Keep 'module' too: the convolutions above consume module-sourced edge types.
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        outs = []
+
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        out = torch.cat([out, post_emb], dim=1).to(device)
+
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+
+        out = self.softmax(out)
+        return out
+
+
+"""
+T
+R
+A
+I
+N
+"""
+def train_epoch(train_loader):
+    running_loss = 0.0
+    model.train()
+
+    for i, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
+        data.to(device)
+
+        optimizer.zero_grad()  # Clear gradients.
+
+        if INCLUDE_ANSWER:
+            # Concatenate question and answer embeddings to form post embeddings
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            # Use only question embeddings as post embedding
+            post_emb = data.question_emb.to(device)
+        post_emb.requires_grad = True
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        #y = torch.tensor([1 if x > 0 else 0 for x in data.score]).to(device)
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        loss.backward()  # Derive gradients.
+        optimizer.step()  # Update parameters based on gradients.
+
+        running_loss += loss.item()
+        if i % 5 == 4:  # Log the loss averaged over the last five batches
+            log.info(f"[{i + 1}] Loss: {running_loss / 5}")
+            running_loss = 0.0
+
+
+"""
+T
+E
+S
+T
+"""
+def test(loader):
+    table = wandb.Table(columns=["ground_truth", "prediction"]) if USE_WANDB else None
+    model.eval()
+
+    predictions = []
+    true_labels = []
+
+    cumulative_loss = 0
+
+    for data in loader:  # Iterate in batches over the training/test dataset.
+        data.to(device)
+
+        if INCLUDE_ANSWER:
+            post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device)
+        else:
+            post_emb = data.question_emb.to(device)
+
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb)  # Perform a single forward pass.
+
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
+        cumulative_loss += loss.item()
+
+        # Use the class with highest probability.
+        pred = out.argmax(dim=1)
+
+        # Cache the predictions for calculating metrics
+        predictions += [x.item() for x in pred]
+        true_labels += [x.item() for x in data.label]
+
+        # Log table of predictions to WandB
+        if USE_WANDB:
+            #graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
+            for pred, label in zip(pred, torch.squeeze(data.label, -1)):
+                table.add_data(label, pred)
+
+    # Collate results into a single dictionary
+    test_results = {
+        "accuracy": accuracy_score(true_labels, predictions),
+        "f1-score-weighted": f1_score(true_labels, predictions, average='weighted'),
+        "f1-score-macro": f1_score(true_labels, predictions, average='macro'),
+        "loss": cumulative_loss / len(loader),
+        "table": table,
+        "preds": predictions,
+        "trues": true_labels
+    }
+    return test_results
+
+
+"""
+SWEEP
+"""
+def build_dataset(train_batch_size):
+    train_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TRAIN_DATA_PATH, question_ids=[])
+    test_dataset = UserGraphDatasetInMemory(root=ROOT, file_name_out=TEST_DATA_PATH, question_ids=[])
+
+    class_weights = calculate_class_weights(train_dataset).to(device)
+    train_labels = [x.label for x in train_dataset]
+    sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
+
+    # Dataloaders
+    log.info(f"Train DataLoader batch size is set to {train_batch_size}")
+    train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=train_batch_size, num_workers=NUM_WORKERS)
+    test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, num_workers=NUM_WORKERS)
+    return train_loader, test_loader
+
+
+def build_network(channels, layers):
+    model = HeteroGAT(hidden_channels=channels, out_channels=2, num_layers=layers)
+    return model.to(device)
+
+
+def train(config=None):
+    # Initialize a new wandb run
+    with wandb.init(config=config):
+        # If called by wandb.agent, as below,
+        # this config will be set by Sweep Controller
+        config = wandb.config
+
+        train_loader, test_loader = build_dataset(config.batch_size)
+
+        # Rebind the module-level DROPOUT that the model's forward pass reads.
+        global DROPOUT
+        DROPOUT = config.dropout
+
+        global model
+        model = build_network(config.hidden_channels, config.num_layers)
+
+        # Optimizers & Loss function
+        global optimizer
+        optimizer = torch.optim.Adam(model.parameters(), lr=config.initial_lr)
+        global scheduler
+        scheduler = ExponentialLR(optimizer, gamma=GAMMA, verbose=True)
+
+        # Cross Entropy Loss (with optional class weights)
+        global criterion
+        criterion = torch.nn.CrossEntropyLoss()
+
+        for epoch in range(config.epochs):
+            train_epoch(train_loader)
+            f1 = test(test_loader)
+            wandb.log({'validation/weighted-f1': f1, "epoch": epoch})
"epoch": epoch}) + +def test(loader): + model.eval() + + predictions = [] + true_labels = [] + + for data in loader: # Iterate in batches over the training/test dataset. + data.to(device) + + if INCLUDE_ANSWER: + post_emb = torch.cat([data.question_emb, data.answer_emb], dim=1).to(device) + else: + post_emb = data.question_emb.to(device) + + out = model(data.x_dict, data.edge_index_dict, data.batch_dict, post_emb) # Perform a single forward pass. + + loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss. + + # Use the class with highest probability. + pred = out.argmax(dim=1) + + # Cache the predictions for calculating metrics + predictions += list([x.item() for x in pred]) + true_labels += list([x.item() for x in data.label]) + + + return f1_score(true_labels, predictions, average='weighted') + +""" +M +A +I +N +""" +if __name__ == '__main__': + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + log.info(f"Proceeding with {device} . .") + + wandb.login() + + sweep_configuration = sweep_configuration = { + 'method': 'bayes', + 'name': 'sweep', + 'metric': { + 'goal': 'maximize', + 'name': 'validation/weighted-f1' + }, + 'parameters': { + 'batch_size': {'values': [32, 64, 128, 256]}, + 'epochs': {'max': 100, 'min': 5}, + 'initial_lr': {'max': 0.015, 'min': 0.0001}, + 'num_layers': {'values': [1,2,3]}, + 'hidden_channels': {'values': [32, 64, 128, 256]}, + 'dropout': {'max': 0.9, 'min': 0.2} + } + } + + sweep_id = wandb.sweep(sweep=sweep_configuration, project=WANDB_PROJECT_NAME) + wandb.agent(sweep_id, function=train, count=100) + + + diff --git a/embeddings/helper_functions.py b/embeddings/helper_functions.py index d5ec672..9699079 100644 --- a/embeddings/helper_functions.py +++ b/embeddings/helper_functions.py @@ -65,7 +65,8 @@ def log_results_to_wandb(results_map, results_name: str): wandb.log({ f"{results_name}/loss": results_map["loss"], f"{results_name}/accuracy": results_map["accuracy"], - f"{results_name}/f1": results_map["f1-score"], + f"{results_name}/f1-macro": results_map["f1-score-macro"], + f"{results_name}/f1-weighted": results_map["f1-score-weighted"], f"{results_name}/table": results_map["table"] }) diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py index f9e4f17..2275741 100644 --- a/embeddings/hetero_GAT.py +++ b/embeddings/hetero_GAT.py @@ -10,7 +10,7 @@ import plotly import torch from sklearn.metrics import f1_score, accuracy_score from torch_geometric.loader import DataLoader -from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv +from torch_geometric.nn import HeteroConv, GATv2Conv, GATConv, Linear, global_mean_pool, GCNConv, SAGEConv, GraphConv from helper_functions import calculate_class_weights, split_test_train_pytorch import wandb from torch_geometric.utils import to_networkx @@ -24,7 +24,7 @@ from dataset import UserGraphDataset from dataset_in_memory import UserGraphDatasetInMemory from Visualize import GraphVisualization import helper_functions -from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME, CROSS_VALIDATE, FOLD_FILES, USE_CLASS_WEIGHTS_SAMPLER, USE_CLASS_WEIGHTS_LOSS, DROPOUT, GAMMA, START_LR, PICKLE_PATH_KF, ROOT, TRAIN_DATA_PATH, TEST_DATA_PATH, WARM_START_FILE, MODEL +from hetero_GAT_constants import OS_NAME, TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, 
 
 log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
 
@@ -47,6 +47,7 @@ class HeteroGAT(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
+        log.info("MODEL: GAT")
 
         self.convs = torch.nn.ModuleList()
 
@@ -71,12 +72,17 @@ class HeteroGAT(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=-1)
 
     def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
-        x_dict = conv(x_dict, edge_index_dict)
-        x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
-        x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
 
         outs = []
+
         for x, batch in zip(x_dict.values(), batch_dict.values()):
             if len(x):
                 outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
@@ -120,7 +126,7 @@ class HeteroGraphSAGE(torch.nn.Module):
     """
     def __init__(self, hidden_channels, out_channels, num_layers):
         super().__init__()
-
+        log.info("MODEL: GraphSAGE")
         self.convs = torch.nn.ModuleList()
 
         # Create Graph Attentional layers
@@ -173,6 +179,87 @@ class HeteroGraphSAGE(torch.nn.Module):
         out = self.softmax(out)
         return out
 
+"""
+G
+R
+A
+P
+H
+C
+O
+N
+V
+"""
+
+class HeteroGraphConv(torch.nn.Module):
+    """
+    Heterogeneous GraphConv model.
+    """
+    def __init__(self, hidden_channels, out_channels, num_layers):
+        super().__init__()
+        log.info("MODEL: GraphConv")
+
+        self.convs = torch.nn.ModuleList()
+
+        # Create GraphConv layers (GraphConv takes no attention-specific arguments)
+        for _ in range(num_layers):
+            conv = HeteroConv({
+                ('tag', 'describes', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('tag', 'describes', 'comment'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'question'): GraphConv((-1, -1), hidden_channels),
+                ('module', 'imported_in', 'answer'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('comment', 'rev_describes', 'tag'): GraphConv((-1, -1), hidden_channels),
+                ('question', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+                ('answer', 'rev_imported_in', 'module'): GraphConv((-1, -1), hidden_channels),
+            }, aggr='sum')
+            self.convs.append(conv)
+
+        self.lin1 = Linear(-1, hidden_channels)
+        self.lin2 = Linear(hidden_channels, out_channels)
+        self.softmax = torch.nn.Softmax(dim=-1)
+
+    def forward(self, x_dict, edge_index_dict, batch_dict, post_emb):
+        x_dict = {key: x_dict[key] for key in x_dict.keys() if key in ["question", "answer", "comment", "tag", "module"]}
+
+        for conv in self.convs:
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}
+            x_dict = {key: F.dropout(x, p=DROPOUT, training=self.training) for key, x in x_dict.items()}
+
+        outs = []
+
+        for x, batch in zip(x_dict.values(), batch_dict.values()):
+            if len(x):
+                outs.append(global_mean_pool(x, batch=batch, size=len(post_emb)).to(device))
+            else:
+                outs.append(torch.zeros(1, x.size(-1)).to(device))
+
+        out = torch.cat(outs, dim=1).to(device)
+
+        out = torch.cat([out, post_emb], dim=1).to(device)
+
+        out = F.dropout(out, p=DROPOUT, training=self.training)
+
+        out = self.lin1(out)
+        out = F.leaky_relu(out)
+
+        out = self.lin2(out)
+        out = F.leaky_relu(out)
+
+        out = self.softmax(out)
+        return out
 
 """
 T
@@ -291,7 +378,7 @@ if __name__ == '__main__':
     config.initial_lr = START_LR
     config.gamma = GAMMA
     config.batch_size = TRAIN_BATCH_SIZE
-    
+
 
     # Datasets
     if IN_MEMORY_DATASET:
@@ -300,7 +387,7 @@ if __name__ == '__main__':
     else:
         dataset = UserGraphDataset(root=ROOT, skip_processing=True)
        train_dataset, test_dataset = split_test_train_pytorch(dataset, 0.7)
-    
+
     if CROSS_VALIDATE:
         print(FOLD_FILES)
         folds = [UserGraphDatasetInMemory(root="../data", file_name_out=fold_path, question_ids=[]) for fold_path in FOLD_FILES]
@@ -385,6 +472,13 @@ if __name__ == '__main__':
     if USE_WANDB:
         wandb.log(data_details)
 
+    # Take subset for EXP3
+    if REL_SUBSET is not None:
+        indices = list(range(int(len(train_dataset) * REL_SUBSET)))
+        train_dataset = torch.utils.data.Subset(train_dataset, indices)
+        log.info(f"Subset contains {len(train_dataset)} samples")
+
+
     sampler = None
     class_weights = calculate_class_weights(train_dataset).to(device)
 
@@ -392,7 +486,8 @@ if __name__ == '__main__':
     if USE_CLASS_WEIGHTS_SAMPLER:
         train_labels = [x.label for x in train_dataset]
         sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
-    
+
+
     # Dataloaders
     log.info(f"Train DataLoader batch size is set to {TRAIN_BATCH_SIZE}")
     train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=TRAIN_BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -403,11 +498,12 @@ if __name__ == '__main__':
         model = HeteroGAT(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     elif MODEL == "SAGE":
         model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
+    elif MODEL == "GC":
+        model = HeteroGraphConv(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     else:
         log.error(f"Model does not exist! ({MODEL})")
         exit(1)
 
-    model = HeteroGraphSAGE(hidden_channels=HIDDEN_CHANNELS, out_channels=2, num_layers=NUM_LAYERS)
     model.to(device)  # To GPU if available
 
     if WARM_START_FILE is not None:
diff --git a/embeddings/hetero_GAT_constants.py b/embeddings/hetero_GAT_constants.py
index 4b296bb..6c4dded 100644
--- a/embeddings/hetero_GAT_constants.py
+++ b/embeddings/hetero_GAT_constants.py
@@ -11,7 +11,7 @@ USE_CLASS_WEIGHTS_LOSS = False
 # W&B dashboard logging
 USE_WANDB = False
 WANDB_PROJECT_NAME = "heterogeneous-GAT-model"
-WANDB_RUN_NAME = "EXP1-run" # None for timestamp
+WANDB_RUN_NAME = None # None for timestamp
 
 # OS
 OS_NAME = "linux" # "windows" or "linux"
@@ -21,10 +21,11 @@ NUM_WORKERS = 14
 ROOT = "../../../data/lhb1g20"
 TRAIN_DATA_PATH = "../../../../../data/lhb1g20/train-4175-qs.pt"
 TEST_DATA_PATH = "../../../../../data/lhb1g20/test-1790-qs.pt"
-EPOCHS = 10
+REL_SUBSET = None
+EPOCHS = 20
 START_LR = 0.001
 GAMMA = 0.95
-WARM_START_FILE = "../models/gat_qa_20e_64h_3l.pt"
+WARM_START_FILE = None # "../models/gat_qa_20e_64h_3l.pt"
 
 # (Optional) k-fold cross validation
 CROSS_VALIDATE = False
@@ -32,9 +33,9 @@ FOLD_FILES = ['fold-1-6001-qs.pt', 'fold-2-6001-qs.pt', 'fold-3-6001-qs.pt', 'fo
 PICKLE_PATH_KF = 'q_kf_results.pkl'
 
 # Model architecture
-MODEL = "GAT"
+MODEL = "GC"
 NUM_LAYERS = 3
 HIDDEN_CHANNELS = 64
-FINAL_MODEL_OUT_PATH = "gat_qa_10e_64h_3l.pt"
-SAVE_CHECKPOINTS = False
-DROPOUT=0.0
+FINAL_MODEL_OUT_PATH = "SAGE_3l_60e_64h.pt"
+SAVE_CHECKPOINTS = True
+DROPOUT = 0.3
-- 
GitLab
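
For anyone exercising this patch locally: gnn_sweep.py is driven entirely through its __main__ block, which registers the Bayesian sweep and then runs 100 agent iterations. A cheaper smoke test is sketched below. It assumes a logged-in W&B session and the preprocessed datasets reachable via ROOT/TRAIN_DATA_PATH/TEST_DATA_PATH; the smoke_sweep.py helper itself is hypothetical and not part of this patch.

# smoke_sweep.py (hypothetical, not part of this patch): launch a tiny
# two-run random sweep against gnn_sweep.train to sanity-check the pipeline.
import torch
import wandb

import gnn_sweep
from gnn_sweep import train

# train() and its helpers read the module-level `device`, which gnn_sweep.py
# only sets inside its __main__ guard, so set it here before running.
gnn_sweep.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

smoke_config = {
    'method': 'random',  # cheaper than 'bayes' for a quick check
    'metric': {'goal': 'maximize', 'name': 'validation/weighted-f1'},
    'parameters': {
        'batch_size': {'values': [32]},
        'epochs': {'values': [1]},  # one epoch per run
        'initial_lr': {'values': [0.001]},
        'num_layers': {'values': [1]},
        'hidden_channels': {'values': [32]},
        'dropout': {'values': [0.2]},
    },
}

sweep_id = wandb.sweep(sweep=smoke_config, project="heterogeneous-GAT-model")
wandb.agent(sweep_id, function=train, count=2)  # two short runs, then stop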