From 2b84694ceb66dd5266e2d90082d23aff68d99767 Mon Sep 17 00:00:00 2001 From: Jing Lin <82669431+linjing-lab@users.noreply.github.com> Date: Fri, 27 Oct 2023 09:22:26 +0800 Subject: [PATCH] retrain Binary-classification --- released_box/models/credit.ckpt | Bin 6991 -> 6991 bytes .../tests/Binary-classification Task.ipynb | 80 +++++++++++------- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/released_box/models/credit.ckpt b/released_box/models/credit.ckpt index 8f592a992c9b6cccbb17d4a35ddd2c34502f59de..d2d105569eac8b8d890e5b6402e20c7a61833479 100644 GIT binary patch literal 6991 zcmb7J30O_r+dro?Q&NhM%1|n4(0KM*B}yDa5;7!Blv16XxQ0@xG!xAcjY=Wns?L7j zOd*O;=2@m>$UMb&;@&&;Gs6e92xMj;t)_zqP)c zejLwGZ$JMPzGHkmg`P$s0fCP8ikx}Bv`7hW@8G~dUvHs*aL{u3>7hQpp}sy7{Jn+p z0?sHG0awqhpBQ>sV2IHa{~%w_P-CN&zWz)6gaQd~VXzq9$RkYXAGloJBh=R`%sU?gTKt92*XN3wvL2nm9 zAGf~0Ez`^2bGbm#PrL@dUVi`KFf}!?Hn*@eH5MrS%+Z%`DikPxr-&8w<8xyLDnH;= zg@XPr0=3_=VDfhs)cM&I|K^)D1-#d*^HnL>vzAruVzP7x~@-m`}hKj23S1-dQ*y|8e>C^tEAnmm^I z3O#!?A<+MxLIXWF@hC@N7%mw7{f4dZ4P7qQgTRQ-UG66NBe8-p;R0j7?;{F7M*dcW zz$9E?`s0eyUsuM7^|ZpXM^nZEvv9%K|0&KyU>+{8_R_&YOr1($74cLZjLs_#F~t=|bhF?PJNAhq@2L9&P@AVO%JMg40v4o_ zM@Po9$$!+aS8LMP{vuP}=Da-S;?fqV&ppE=tPd9*FqFe>CH6$eN0-io!>m`$6Ly=R z1MZfkTZNphCBdItm}g`4m?syv}$$@bIf@b`F1Ofsg89fcQcG>)+;Xa z;9vyNH;RHCYj&^;CN?r^3g<+e_f_O&;%2zAElu=z@gVA4UI-6GQDD3z5Bw!MV1ry5 z1G~n6VfIJp?mj|T6MK~HGlMyG>aECdpp8gAR{`&gwI=?BLzwM{DO{RYF7j-Vg_YV@ zq5h%^X&f(th6&F}9~)DWLK>K-2a!>c?uT_of3oF}0uF_J;e@x4Sea+Bm4_R6PKz#q zWP1WJ$gT!E_cET-#ya*~>n_GeQJIV_XlEQM4}f&@N?!LFKUM~JkRBDxtNd8EiOb>i z^#4r-djF;Z@h`*Ce)#`M1!vneZJo)8`X1it<8N2@9Hc${f8#a&7v7#z#nImI|HS*Z z46h83s(&9Thpu-gpizf0Ruo-BVXXTKN^$h_GQ7}~jKO6t z*feiE9^MjL`Mvs%> z?8&;nE4Q#(u|r|c zrfF!jxRlAfI0&zmp5r-}-DBsiIs@)!t3ch|1E25@o3+F7%Jd+d zRh~mE-)ND=AyMe@H4U5A*6^@7o183*$BMd>FgWD{Hdpqq|IB@eMKkXp$qL1^^Owm) zx7VcaqcJdHrWGjc&ZbS?8Tg^%8U&9UhqvNhi5~VFiLFo4Nr(xWGbE#o+JKm5CU9QVQQ~ zHzqdW-LUWQJy;dc%w&ApLZ%%$LB1r6K#ln4xVD2^-+p>C>i^NFzQOGt=3Q2X9Tev!l$8x2Ojp6DQn_q>;T)X4 zzZTQjTwEf4PFPOJLGJc|`q+{o^(SrS)=Nr^sDFKK3Wnaf4Rv)H5TkPwW~$!>q1FIQ zQix%ND!E9L#>1t|mteU;9g9|tg4)UN*vD&bG9IJD@rKfAGV)t380=7`=kE7n!`?`d zGzTtjahgQdx9H)!=69edEyuph=AcZ{Nl1OgL+z90a7F1MqkhQ<^A^g(;7ij`=J95z z&DufCKdmS4uN&ZK0Sm@S(OB?q63P~o;iB>0$mc!coz}6T>0e`shI=ZVzTJ~#X5`VE z%|r0!z+O!7)^zYx(!udbj!5$J*>CwzSdaXzSea|X#`RaE0pm32+Kyau;Pxim=y#WF zDV%}UN7b-V{2cq{aff_9*N;Yg*h-x9{V9!%6q^@Q`C09*L2lO`8&_ zTZDqG*MZx#>0Aj-iBhuF+gC)q1@m#Dej|84YGrdwCgWWD6egZKlw24T0cN6ks1nmc zw539!%tRY%%oz;RdjpHMSmDxBolIhCF)uROjJA6v(T9F}KK?~9)sHNne3 zMY)*~zLdhu(Z`wl13WRHXci3Ue^u1&ErVet6t|hAqvhyy9L(8<)NObo zUWn7i>*9WyI68~lNaLKI({Y#eZ69To(b+*Ibl-~`qO{zvFur9jleVq`G#^&M!^myW zloTt{nSF|lc&dz_BomnQSEl%}RGayrIU3csS3qa|DvTJy$MpV@knqX|#w_baVw*={ z@#2#p9j1#q<7?ph{Whk6&9FT2@X71M+19yW4EsidO*>0P0%?NM&%pb}w z9%jI*&l-*$XD``kFVe@?tKJgl=y9a<))>%gI13)7+i?@S9@9Ebl7s{|9NE7LPkvnw zF$Wb;Z>$O}eHKTogU&$5{d97?{VlLHpPAgVIV@*07ekC?>C|`OnECh=3=u9scS~PP zzJ83|J9~_+`G6MMOmEOBk83E`>LKMzUto2z`oSr;HaO~klRXy_#^k%0L)No)s7^{@ zG9=o;_v%^Zcx@vIr~l3SfoI7Xn1B}@ceH}e3UDB41M zDUPJIcN}SU#a80$^$jXcsN&Wx1(J5Xn(bhR(rhy|=-PH1Lee%eOa1j=|L(Oof0-%z z^V3kv!4jW^ZO7)X7uZ>)$H4E?MdpG^F}Zp*9u_X%KqFO0P`9}+2;HZNi}0AJHE1xf zxn^Y4dKv6Gc^{r=Po=Lm#MD)vRHtg?q1gHKIoMeUc;WBTuu*aX7O##ZMf*OJ?m8Y` z^_hW3->(IGz9c?BmBWTv)qz~|dT=>f2rHBHuw+Ok+GtgXlHNyw$4)t%y}%fE?Hx+) zcpZm3>P|#bO{|Ogt1&TK0&V+sGfEn1D6zT%TaNC4+Xf9pV4X)_F&pW$_mfDgYXKQO z--LP_u0f9Vb@Do45?)A7W5!=m#gZ9r*qRW779aA6;?gciPd1^$X%t&+H3BC$PDMkn z46v(G0_Afba8CD899p~!O@^F+2WMo-1zjN>8*W33g?rgygFBefr?m0P<4nwY`w41? zd>~#bIn0Bq6nJ)JJZcma;dId+tVwMV?Obsiy4SQ9C+TSD5la=waQ6!CU z8)VzEzl#oVtfutRUOH)C0R1L#f_wW7SnbmJq;k+fxG{1kN%K4fi{esom*pN_xfR8v zu#MQ(&`td4G+u_vA{QbM+qtY> zP7;n7q{^DNQS{JoWmozhBZ0g=`1(OBY5g#dwCxEe9=(leKgA_5?TIuyQ7al3yeR@{ zEgO-|&D)}yZ811I<2@;STEW&_Pb1lnxx_Nvi)p$mNzc#8#5IdHkV!RbaKW4iEEnB| z(v6m6VWS$fg|rd}i%}S-9*4c>AK}@b>W9Jdfcu=6qk{i77{1sB_ut$A!#YAS@`VyL zS16|&zcesS4QXJLvzoqfZ6p0>#Q`V8gApivfjg>FP-vG7oJ4hk&6DZJ8Z~sh^NqPu z`i80U_Fy?z55q~`LHt^M4dLZ>%wpbS6JMp?QhEUzmz~D>5i#VkqZ-QZzCv^DxJX8> zhv_%EcwK^#jO3_I=x@enCClVcOTi4X&zMnZ>#IywU@$&24rL3>d_b0)hh6WNl3ojw zFigURcnoWzgJx`nobGKjMBj_v&p1TA7X{O6lZHU}u37l8_z9eLh-25^8iDR(^I^ps z4$jth#VJE1k&g?pZS;O}&EYWE=vw3A*RHT(dIdAsD-1hzcEX8hN2)!ygRS2zOC#IG z@6W|U$iU^hN&hc{aZW@Zay&i`T2Jm4J(}Z)p3ijQ%*;O_bC3j4(~QR4Yc2kixcY-VjbiMjrWrLVJza=I!ysHdRZ z@+f@Q>V(R+mT=+RA2@uE6CS;=2YUI3W8XJJiNaiY*yLb=ITQJGCEthjm?BU3D$&$r z@+|gspLn4ACZczx4RY^K$B15aAYo+5WZt~UE{Q_m`)HwIqm-y)=5>f^JWKWqOCa3E z1FdQ&<1j5DzWOEy_TfH|I(7`a-aCbANG+ge7aSuqre=amfdu{B&`4C8^kI}*Z|wNA z9&&=RAk0gW1|5A#DyP-LTu}n*oOi{YPjk@EsEM7pK@}?lJ28wcz^JoH=o!smibNMU zHl2b)m+Qd6b{^T67)dk=GwB}Fm3Ttg4U)XYI-WHT)?XZeMqTM_pGo)0$G0D#wR#qg z4DrA(C0lUDmqe5nbKUqC&80Cbt zkx5Snqico-(6#S~Pu&c>sji9I6{&EAuZ{63&q?9qvtaltgQ?rK8fRBtf!+$9sBmZ~ zKCRk=8T_kQeE1I%;`fp4DvV_KTr+y$+cI+8>M*_2kjGPZ)Whzh0obgsgT?7{fUJ^c z{p?y` z0R2cBnzZrp&p(Qo;>m5SiESyo8+0A^X1idQ^BPz;dN)xwjwMqY`FQ*>)w{dJiyO5Q7GB zYjEL)6f_$dfnIMl*vn@EVcNBMjAitbI;rf_r1wKx2--%HhqP4VTUmS z7lH3`eejjyqVw(s!kHe*^g6bT9x}JYi=Wpr*F*s2=DnzL(F=$u3LuAb3t_Xe3{Jg@h9#8kl?fLkprt-5V-BV-u zwf$aU`91n)FH_@}&}(9rpQC?uv46+;*_-?o=Z?7k^Dj7JAM|&epFO%?aonW-4X4M~ z{2k}#-2aM`Ew0@BxApwv@BNPRbLM-B2mfiZ1OJ2b*L`#5OG_L6HA7Q!^gnZZYDGQm o@6SB3aop2f#1K7A9M@BZ;yn9nr9GFO`BLAfON*m==Kq-cU-9qm=>Px# literal 6991 zcmbVR30RF=+ispUr+G3oOCzdxt(~M&R60^JR4Nq>wzhC{&_jKBg0YalY@vcfRlX|NptJb-mYKYd!b9*7H8=TJOFeUoTNHAt5O#p?@<) zA!VV^MPX6Vi4mj2LwTW=3m3$=xyTCz{a`Ca!ouQXVj{wL(eZH!GCqsKBNj!3yGDoc zWVk}6eq3RbK;;4G*qDWulcVDzLKj(CE{TYajN)-c!g%om@Rq@gdC@TmGQo=?<}Qwo z2@mE)#3jTp3Qn|<;fk7=1WFEgCk)ICjf~)meRGjU!q>=>@X_Z*5|HaxDn*NP<*ws7+>h>3VfUP zfOy>%WIZeRf#-w-!(M9&W_8TR)d?$sTON+TMY~|+v{C3D|A?&l056XSFGu;}_dVicZ2WS7XWziAF7 zL3JTy%Y9@|7+&M=*nEy@daQ=Y^A^EpKRJ^8hXtR`kz=RSrjr8+pP<;ynz)9XV6E=l zajH~d=rJ^64ecu(SGeVXM-Y&Mu{)u6Z@81Yfh50G>3Xj9c5p@xke%%>F|%%#2ch(5aaQzb#E6GY`*l%4sQMiybxSjrw9J znW%;zj#-ftVx|0Uv%SRcqa^518Au;fM?!+jnC>IH+3k;pIWfl`GwO1Ae8+P|kQ>xT zBAy3;o|~bQclRzQAN33-`pbUUv^j`~#-)-V6-{z<>29*dQUq<%jYxm;HTGy@5lO9h z#ZK8egH*PhWIwJHBJzPcc=uzWV``d!sW_bvS4=O%U>@8k+}M;SEF?78{#zcH{M$MD z*A3w2GU9)AjykI?uqd&S+@7+RNEnW1dMrM$?Rwe7Jm4gx=5eskLxC!s^g*$=T;hIv z0;r5`flu0x$uTt^8jX;mcxM9Iq@IGgmLo8rel(hS?}4*p{{V0OOpG6%izC}V;-VXg zm>F4#hbKDXk}qShl^`jIvLlM3m+0igopkO>b(+|`9v_yQGDD6nCoOBOv2>#!O0afd zVt*G?%(<9i)`eR-=U~#2d~~{*Kzj6(0A~3M<}yZt)~%`nEjc4W^<__i(p^Kr_Oh{p z!yIkF;z{24n@usU*GXk#n^)Cs^ie^#@)kx>Zvt)_cbe@wGZmWZWYHo_mOT(>K%cb7 z<73OSFdhuB2qgU)Cb{FF#B$HRcwTTLfhoYs-Oc8769;ch(iWD^Kh%i9 zXpPV*T4a6&%yqYrA!&s zV4y@#o}WiE@^+CY(NfGd5iYo2>4*6Fci4?;T6AmYZW@>EOc%5!GHV_yquq8>I;q|X zZ#{YsA|Edjd!adaJ3SsNl=Tp!N7CFi#o!iX!a#>QF3T4cth_uA@5V_89@GsH7%E5! z9xU63vL$VJS8SM|yOH4W>nG6J){<9*d^LhML3;JLpUlcZqvNJJlHCePcpq~ftx*!OzgY~ zVf(A#md<9pU+<0fa#d(5bC9i;T+P0G#=&Os?QmnhIBpyfh;^o0oSt6EVcr=&rxQk( zQJsu_>U-cJ{e0In@yBzjzIj)&(LBp z7mGHPI{Eh3g4;QJR$0E2H1p?B;|dcZv@VDIdbx&fcw9u3l9Ooo$i4K#zA1FQ{c)P9 zJOj_>xsp%KdBp6VIm&AXu^*(0i0Hl`Jlw{|6RTqJ?Vo12{GBXD&MC$ZwN{v!t&JNR zo{{$8iMaWB8OlB-RG_zy7F^lR)^&N}d#?5rWS3HO(iX^ z>#23_PRNbafN|?K(a*)5V0HBn$aHUK7u6Nv?`z61vTcp4_4Mpwyr3<-UXkM=l z#as>Kk$@sE3|$RY-%l8$4!SXQ-q1z%@J7rQJ6gR)50fxHt2S`9RyP) z;m${YCd=#tR=YinUd6+dWj$i4ljqkn~Vw`a)Rt9A7U+9LSe`yYknmCl|*I$MJ zhxvG^p@FoiJt0Ll^~~{_Dk$At&c-Nj#l%Yyuw#-1@=Z;!TKo%{x*!#%&ya?8*LJdD zdJS4GWfXp_}CpoR*M5|ASlok8ps|E#{yj>j2M>l|3wgi6C z)Wcy9MNsq160&MV8vAhHD4Z#^nvE{wu^XFxfH&a+3Y%0@qeXetV%j0XYDH0_&pOmJ z?-b?4?#4}J8?mEA0nb&Bhr6vNcw*B<2u!@fOz*u6F*o>V`?QkrnO=)4B-XN%!&l?N zThlPu@D_VLHXDq7Vl5PCgU!+ zzd#y=-z#A7kr}vuaV})`rDDDA1jzXFBGhb^q95 z%M|)+LMX)o1?;jdBd>0cfx^)oc(zOmqo@1A$7LI#?8QOw5Jqyf^a8X0(L5Y-*BFD& z4ngB}X`pImfH!x{#NkdpIQ86P6hBptemnRWHF`PbSDix_&p!OUQ57wUD$!DOh9cPd%b2jF=@3y?>!4~P4#o&1;!Dvk4FJ@rO1t(Y=RQv$bGmPU#fZF`78n# zn;7DGzZ-<)o56WW21&L!1xce1!sob4WNNx64xd_zlb76p^MN*?td@x>E9YT+({VWS z{1F>J@d3nLybddOSduM`F7UZC0@5~&!yR90aKhU}tdN?FPmdSE(@o9P(Yls4A*CzA zhtM}KVrgpcEOM-OB>M0*LH%wn8dj8$%1wrNj_*uPZA--`TY+3!UCqX9l}6b=yP5Ed z&7?pJb)4i(_f{b#D{jMO-M(HSuRfZS`M+#{kPEVM-P&t`{vOcb93xhzW~R3 zcd(nJjNtu{{cxdR6S;9R2fN>Bpr=DUNiveh^>VwwKT8?iJ2%6FSrdrU$6EaJy%Y*X z>cGP$E6xb-cr~c(&+T` zQzY%`o#?0Dg90_mS=@X-p2RA&2jI zW244a@`#Ex*Nl^=R}bmFp^wgol)y(qmrgIxp!a1ds{gTyEPGbn1M zB578{&xzjfVr&*Z>@lM1*>*JbYyf>ywSi8NZJ}#yTFE8xCKzh63<~z9V26$@oYm#x zhNMSqVpIsUcBGJ(8+u@!;1ro~!y4lsC8K_06Erw9;4f{H1TmcVSl8?>XiZ`S<9lQU z%aYBpCAt6~1gWA@WG+1YBY{m=evvuDUU15|TLP&A-@;V~weSq*3+Sy+!c7LpNZZal zY%E$!y23m$Yi%9~n|q@CyY*O|+0AdOSBH5;8ce{zeSG+CIn%auCiw0*2H)MgX!5P; zeC@T;sL?!~*j5Z@qu50#oc{=RTs#lHSA39N;g51de9%uzm_g-ID+IDW4jjY$=y9?KrUq+9wZ zFWC=PW5$B^NDs7|lU>&y-44>h);N045!m-qfSlLH0>{a3P@?}aE?Bh|8{`k5QA;U? zhxOsG*hq{I?5VwR*O%p9S7k3&9)(NSH8F_4p1(WR0<@iH;QXFpV0M-a#vF0*pVS2Clv9m|%o&ar;>fJ*Tt`;7 zKL?4O_V$eh&{&*(_u8D^fUl2zX(D3z*5Zg5fMCgk`V-RI?!sv z6e@5(NyTO@r=LV4z$Yw@PCM8~heVDiQRS~d*ZMPzOxS=A6V1rH$VTY8x*Oz=A4SXc zbMgG0D%{#a0p1$8FJlulz21)da1U;ktU;-j!UEN%rN|vu zj~P=<$yhso*0fxOgF(cX>J4^W`2aD@ema$65HM-V8r$ z8lu{yYzS=^hhDv57|ajCqbG!EtDh-`Jyl_<{o?S5{Z1w$`Wac?u?I}tufx10YvAcq z9$&_iV*e*GOt`ZXRMzGbvFd0v)XqiCuswKz8v*f>`_SIi0GHk_Leh2|d-mADf%5gV zp*MlX^STJP>lYek-9cp^G=Rp8EG!%sduspe29ke!4SFv=NMv)oh*98Xl;Eh6Kjul% z&T%7Aac3zj6p)6u-$V|4XBt7#*ha7obVXAoc@X|&M#kmp;KDjpNa~V;HHzn%jYCb~ z?6yMWlH;WM_lf)i{Q0O^X^r-8l1bgO;+_ zsSxH+30QfwnD0|H2aZ%fC(aA>*eMy+AeXC$$KM)&U5yJKEV;*gk*;wv)7C^c8%xKW zcg5g(oI$A$4S~DTa6!^pJpp_~>;iQ>t6GGg zXLq5nR4eJWcO_Xnmr?DUY#JtggO!=^Cy}d)0mnsgsK0+0=mz@Zv83B<_HakIdeaQ@ z&sMWF+D4e1q>BqzPQm^h8yxdF7xY~u@X$3`NI&QSX(22gdRRvHhg8s&K|HeA=LEfU zE{kHf9OrY&UXV5B5b4NOxJA|*LRL>>&Yjdoh1q5}tOcQ=KLvXWm*LnaOTk#y?W^8+ z*rQDq2ClK&8wY-I)&tCq4z{{hKaXt4Qu2@gL0SyTC5 zlOC)w{AfQ{Sbp~YUQ^ck!S(h4%lF0q|1Hi}J^E*y@3pxfaf$}& zNPmwrsLTG0^L^}p#3>)B-29C5FJ13vobMw)SUmWbNooEI=dbtX>m?y!@z)GJ(NX`L uJ6J0kw0}HH2872!^BaH|v;n`tGL%r?Un?CP_Vp6~I$dJGYjFOzx&H$Jg22iE diff --git a/released_box/tests/Binary-classification Task.ipynb b/released_box/tests/Binary-classification Task.ipynb index e4fcaea..0a1d85c 100644 --- a/released_box/tests/Binary-classification Task.ipynb +++ b/released_box/tests/Binary-classification Task.ipynb @@ -341,7 +341,7 @@ { "data": { "text/plain": [ - "((30000, 23), (30000,))" + "'1: 6636, 0: 23364'" ] }, "execution_count": 5, @@ -350,6 +350,27 @@ } ], "source": [ + "'1: {}, 0: {}'.format(numpy.sum(labels == 1), numpy.sum(labels == 0)) # nums of 1 and 0" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((30000, 23), (30000,))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# dataset are imbalanced by labels\n", "features.shape, labels.shape" ] }, @@ -371,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -406,7 +427,7 @@ " ('device', device(type='cuda'))])" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -429,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -446,23 +467,24 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch [1/2], Step [100/3000], Training Loss: 828.4427, Validation Loss: 1241.4515\n", - "Epoch [1/2], Step [200/3000], Training Loss: 94.2689, Validation Loss: 64.1411\n", - "Epoch [1/2], Step [300/3000], Training Loss: 0.3855, Validation Loss: 14.7295\n", - "Epoch [1/2], Step [400/3000], Training Loss: 0.3561, Validation Loss: 4.1548\n", - "Epoch [1/2], Step [500/3000], Training Loss: 2.6463, Validation Loss: 4.9783\n", - "Epoch [1/2], Step [600/3000], Training Loss: 0.4240, Validation Loss: 3.0171\n", - "Epoch [1/2], Step [700/3000], Training Loss: 3.8731, Validation Loss: 1.5793\n", - "Epoch [1/2], Step [800/3000], Training Loss: 0.2598, Validation Loss: 0.9340\n", - "Epoch [1/2], Step [900/3000], Training Loss: 0.5623, Validation Loss: 0.7461\n", - "Epoch [1/2], Step [1000/3000], Training Loss: 0.7324, Validation Loss: 0.6750\n", + "Epoch [1/2], Step [100/3000], Training Loss: 0.0000, Validation Loss: 751.9165\n", + "Epoch [1/2], Step [200/3000], Training Loss: 151.6677, Validation Loss: 147.1012\n", + "Epoch [1/2], Step [300/3000], Training Loss: 11.5891, Validation Loss: 11.9286\n", + "Epoch [1/2], Step [400/3000], Training Loss: 102.6795, Validation Loss: 6.6973\n", + "Epoch [1/2], Step [500/3000], Training Loss: 0.2132, Validation Loss: 4.4743\n", + "Epoch [1/2], Step [600/3000], Training Loss: 3.4273, Validation Loss: 3.1016\n", + "Epoch [1/2], Step [700/3000], Training Loss: 0.6728, Validation Loss: 1.4565\n", + "Epoch [1/2], Step [800/3000], Training Loss: 0.3786, Validation Loss: 8.3835\n", + "Epoch [1/2], Step [900/3000], Training Loss: 0.4249, Validation Loss: 1.1698\n", + "Epoch [1/2], Step [1000/3000], Training Loss: 0.5671, Validation Loss: 0.7559\n", + "Epoch [1/2], Step [1100/3000], Training Loss: 0.7306, Validation Loss: 0.9674\n", "Process stop at epoch [1/2] with patience 10 within tolerance 0.001\n" ] } @@ -481,14 +503,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "loss of Box on the 3000 test dataset: 1.061116337776184.\n" + "loss of Box on the 3000 test dataset: 1.4942820072174072.\n" ] }, { @@ -500,13 +522,13 @@ " ('column', ('label name', ('true numbers', 'total numbers'))),\n", " ('labels', {0: [2282, 3000], 1: [2282, 3000]}),\n", " ('loss',\n", - " {'train': 0.5675961375236511,\n", - " 'val': 0.6755267381668091,\n", - " 'test': 1.061116337776184}),\n", + " {'train': 0.2478877156972885,\n", + " 'val': 0.6973171234130859,\n", + " 'test': 1.4942820072174072}),\n", " ('sorted', [(0, [2282, 3000]), (1, [2282, 3000])])])" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -525,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -545,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -562,14 +584,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "loss of Box on the 3000 test dataset: 1.0611168146133423.\n" + "loss of Box on the 3000 test dataset: 1.494282841682434.\n" ] }, { @@ -581,13 +603,13 @@ " ('column', ('label name', ('true numbers', 'total numbers'))),\n", " ('labels', {0: [2282, 3000], 1: [2282, 3000]}),\n", " ('loss',\n", - " {'train': 0.5675961375236511,\n", - " 'val': 0.6755267381668091,\n", - " 'test': 1.0611168146133423}),\n", + " {'train': 0.2478877156972885,\n", + " 'val': 0.6973171234130859,\n", + " 'test': 1.494282841682434}),\n", " ('sorted', [(0, [2282, 3000]), (1, [2282, 3000])])])" ] }, - "execution_count": 12, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" }