From 4089bcb751efc095c60a7710825c98fe65b7d4ec Mon Sep 17 00:00:00 2001
From: Shivalika Singh <73357305+shivalikasingh95@users.noreply.github.com>
Date: Wed, 16 Aug 2023 22:22:36 +0530
Subject: [PATCH] ShiftViT Example Enhancements (#1011)

* Updated classification with GRN and VSN example with links to model and demo app on HuggingFace
* Removed warning messages, etc from md file
* md file fixes
* black formatting fix
* Updated ShiftViT example with model inference section and hugging face links
* formatting fixes
* black fix and minor edits
* removing generated file changes temporarily
* black formatting fix
* black formatting fix
* Added suggested fixes
* reduced code
* Added generated files
---
 .../vision/img/shiftvit/shiftvit_46_0.png | Bin 0 -> 26263 bytes
 examples/vision/ipynb/shiftvit.ipynb      | 330 +++++++++++++++-
 examples/vision/md/shiftvit.md            | 359 ++++++++++++++----
 examples/vision/shiftvit.py               | 207 +++++++++-
 4 files changed, 800 insertions(+), 96 deletions(-)
 create mode 100644 examples/vision/img/shiftvit/shiftvit_46_0.png

diff --git a/examples/vision/img/shiftvit/shiftvit_46_0.png b/examples/vision/img/shiftvit/shiftvit_46_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..ad616c3bd17b25836fc142c144952c12b4db8e5e
GIT binary patch
literal 26263
z)K|`|heE#^%5XNe^b7+c8*OB(r`=tCz85S<(%uxsS=RM!J$TWJTQ(ox6Ca(Rc5=Bk z@7WA0YYRAn6mVoG$WH#K`rp&xzea)$Q6-8DZ`G?Uflce;c~{Op#h+5ETqnWd>2!)`1an(+qlmL2 zC442srD5M>DNbeQJw6v0!9lxq{FvFg9MsP5tgjs1QjK&fWkN2M>gxUg!VLLR>4>OarWB-eEu zrWR8|o>^SK~Bd{Ikq($wC%Dcl0UA51@isN^NX?)o5 zpzO7&y;7n=weB#?uX|3R>kZ>%m;Z!PkEJ!vZ|%b&JL!JA55JYU3PK6eeeUbm&{eFp z?6SZ8&gI6}G_R|1A7voSrnPp$#7GF7 zt*xDD3`*1Vg%g@@i`jA^L>EJC!&A9lhJ|=<7MTQ7fM^p}fv6%?PZ29GPRriIKiUE- zt<`~E(-UFmdM-L{O|82G?bUj>bGn+xJuZ_tSxLF4D#&q$*Oxx){T;Xek4!0`=5+Mg zP`6vmG6o>B8!8wf^wY(1U3bbFKR_^GlwL=Jty~09UIO zz<@;`P=piQEy6DaT;31Tz8>Szci*o*eHUXHsk%h}CT&WlhU8k=9%SvtCQ~+kZL+(d zpb!3CP!r@wsJ5UpJ#QzjwhI!r9aI@u%$n;d4tpe|A1@M-ntoCb;4;qd#gRp5XC%$q zR}+zkS)@FWAKA!SfN8`&RU{Ir#yq5xALW2>vepThXs3-h59y^8trjshwXdtgj3dZb z&`puM?8HzuG3>_G>gXhYO<5-o)=tjwO}Vc-qiU5eXGgU`NtAYyc$#0>+X{yQuBdEQ zgK6H#*>Ll)xpRbSJ6W;f9IL+}xbf?&xE*5*@K#Vqs{?#8guj}AW3BdtPN<{wBXMb0 zjv`C`&ag!Oxn%u6hQ)JZ4nF!Ms`;{vQg_;Sl_RIn1#aEyMkF7yewN#!QBd+HJCzC# zqeCk_1E?E=hF6KQd8B@9m4~sl>W&5a_(>}|VgnFhG~%dH+gL9suOP#NEXkxJf4L0* z5$5xJSi}?dto9Eyj}dYHiw1L=Czj&5ZXeoE@Q@Yp&E;>cCHPT+>-0Gfx>g^2)M2?M zmFc+mrArX4wdtzSbS#?|?wC+AyY1~~g}2qbDF0xz*pCMriIa zGJNsdTvBx&Db$DEDb+mTm|s#lSoa\n", + "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
\n", "**Date created:** 2022/02/24
\n", - "**Last modified:** 2022/03/01
\n", + "**Last modified:** 2022/10/15
\n", "**Description:** A minimal implementation of ShiftViT." ] }, @@ -39,12 +39,19 @@ "In this example, we minimally implement the paper with close alignement to the author's\n", "[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n", "\n", - "This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can\n", - "be installed using the following command:\n", - "\n", - "```shell\n", - "pip install -qq -U tensorflow-addons\n", - "```" + "This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n", + "be installed using the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!pip install -qq -U tensorflow-addons" ] }, { @@ -70,9 +77,11 @@ "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", - "\n", "import tensorflow_addons as tfa\n", "\n", + "import pathlib\n", + "import glob\n", + "\n", "# Setting seed for reproducibiltiy\n", "SEED = 42\n", "keras.utils.set_random_seed(SEED)" @@ -128,6 +137,21 @@ " # TRAINING\n", " epochs = 100\n", "\n", + " # INFERENCE\n", + " label_map = {\n", + " 0: \"airplane\",\n", + " 1: \"automobile\",\n", + " 2: \"bird\",\n", + " 3: \"cat\",\n", + " 4: \"deer\",\n", + " 5: \"dog\",\n", + " 6: \"frog\",\n", + " 7: \"horse\",\n", + " 8: \"ship\",\n", + " 9: \"truck\",\n", + " }\n", + " tf_ds_batch_size = 20\n", + "\n", "\n", "config = Config()" ] @@ -284,7 +308,7 @@ "source": [ "#### The MLP block\n", "\n", - "The MLP block is intended to be a stack of densely-connected layers.s" + "The MLP block is intended to be a stack of densely-connected layers" ] }, { @@ -315,7 +339,10 @@ "\n", " self.mlp = keras.Sequential(\n", " [\n", - " layers.Dense(units=initial_filters, activation=tf.nn.gelu,),\n", + " layers.Dense(\n", + " units=initial_filters,\n", + " activation=tf.nn.gelu,\n", + " ),\n", " layers.Dropout(rate=self.mlp_dropout_rate),\n", " layers.Dense(units=input_channels),\n", " layers.Dropout(rate=self.mlp_dropout_rate),\n", @@ -382,7 +409,7 @@ "source": [ "#### Block\n", "\n", - "The most important operation in this paper is the **shift opperation**. In this section,\n", + "The most important operation in this paper is the **shift operation**. 
In this section,\n", "we describe the shift operation and compare it with its original implementation provided\n", "by the authors.\n", "\n", @@ -693,6 +720,24 @@ " if self.is_merge:\n", " x = self.patch_merge(x)\n", " return x\n", + "\n", + " # Since this is a custom layer, we need to overwrite get_config()\n", + " # so that model can be easily saved & loaded after training\n", + " def get_config(self):\n", + " config = super().get_config()\n", + " config.update(\n", + " {\n", + " \"epsilon\": self.epsilon,\n", + " \"mlp_dropout_rate\": self.mlp_dropout_rate,\n", + " \"num_shift_blocks\": self.num_shift_blocks,\n", + " \"stochastic_depth_rate\": self.stochastic_depth_rate,\n", + " \"is_merge\": self.is_merge,\n", + " \"num_div\": self.num_div,\n", + " \"shift_pixel\": self.shift_pixel,\n", + " \"mlp_expand_ratio\": self.mlp_expand_ratio,\n", + " }\n", + " )\n", + " return config\n", "" ] }, @@ -780,6 +825,8 @@ " )\n", " self.global_avg_pool = layers.GlobalAveragePooling2D()\n", "\n", + " self.classifier = layers.Dense(config.num_classes)\n", + "\n", " def get_config(self):\n", " config = super().get_config()\n", " config.update(\n", @@ -788,6 +835,7 @@ " \"patch_projection\": self.patch_projection,\n", " \"stages\": self.stages,\n", " \"global_avg_pool\": self.global_avg_pool,\n", + " \"classifier\": self.classifier,\n", " }\n", " )\n", " return config\n", @@ -807,7 +855,8 @@ " x = stage(x, training=training)\n", "\n", " # Get the logits.\n", - " logits = self.global_avg_pool(x)\n", + " x = self.global_avg_pool(x)\n", + " logits = self.classifier(x)\n", "\n", " # Calculate the loss and return it.\n", " total_loss = self.compiled_loss(labels, logits)\n", @@ -824,6 +873,7 @@ " self.data_augmentation.trainable_variables,\n", " self.patch_projection.trainable_variables,\n", " self.global_avg_pool.trainable_variables,\n", + " self.classifier.trainable_variables,\n", " ]\n", " train_vars = train_vars + [stage.trainable_variables for stage in self.stages]\n", "\n", @@ -845,6 +895,15 @@ " # Update the metrics\n", " self.compiled_metrics.update_state(labels, logits)\n", " return {m.name: m.result() for m in self.metrics}\n", + "\n", + " def call(self, images):\n", + " augmented_images = self.data_augmentation(images)\n", + " x = self.patch_projection(augmented_images)\n", + " for stage in self.stages:\n", + " x = stage(x, training=False)\n", + " x = self.global_avg_pool(x)\n", + " logits = self.classifier(x)\n", + " return logits\n", "" ] }, @@ -976,6 +1035,15 @@ " return tf.where(\n", " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n", " )\n", + "\n", + " def get_config(self):\n", + " config = {\n", + " \"lr_start\": self.lr_start,\n", + " \"lr_max\": self.lr_max,\n", + " \"total_steps\": self.total_steps,\n", + " \"warmup_steps\": self.warmup_steps,\n", + " }\n", + " return config\n", "" ] }, @@ -996,6 +1064,11 @@ }, "outputs": [], "source": [ + "# pass sample data to the model so that input shape is available at the time of\n", + "# saving the model\n", + "sample_ds, _ = next(iter(train_ds))\n", + "model(sample_ds, training=False)\n", + "\n", "# Get the total number of steps for training.\n", "total_steps = int((len(x_train) / config.batch_size) * config.epochs)\n", "\n", @@ -1005,7 +1078,10 @@ "\n", "# Initialize the warmupcosine schedule.\n", "scheduled_lrs = WarmUpCosine(\n", - " lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps,\n", + " lr_start=1e-5,\n", + " lr_max=1e-3,\n", + " warmup_steps=warmup_steps,\n", + " 
total_steps=total_steps,\n", ")\n", "\n", "# Get the optimizer.\n", @@ -1029,7 +1105,11 @@ " epochs=config.epochs,\n", " validation_data=val_ds,\n", " callbacks=[\n", - " keras.callbacks.EarlyStopping(monitor=\"val_accuracy\", patience=5, mode=\"auto\",)\n", + " keras.callbacks.EarlyStopping(\n", + " monitor=\"val_accuracy\",\n", + " patience=5,\n", + " mode=\"auto\",\n", + " )\n", " ],\n", ")\n", "\n", @@ -1041,6 +1121,211 @@ "print(f\"Top 5 test accuracy: {acc_top5*100:0.2f}%\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Save trained model\n", + "\n", + "Since we created the model by Subclassing, we can't save the model in HDF5 format.\n", + "\n", + "It can be saved in TF SavedModel format only. In general, this is the recommended format for saving models as well." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "model.save(\"ShiftViT\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Model inference" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Download sample data for inference**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip\n", + "!unzip -q inference_set.zip" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Load saved model**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "# Custom objects are not included when the model is saved.\n", + "# At loading time, these objects need to be passed for reconstruction of the model\n", + "saved_model = tf.keras.models.load_model(\n", + " \"ShiftViT\",\n", + " custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Utility functions for inference**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def process_image(img_path):\n", + " # read image file from string path\n", + " img = tf.io.read_file(img_path)\n", + "\n", + " # decode jpeg to uint8 tensor\n", + " img = tf.io.decode_jpeg(img, channels=3)\n", + "\n", + " # resize image to match input size accepted by model\n", + " # use `method` as `nearest` to preserve dtype of input passed to `resize()`\n", + " img = tf.image.resize(\n", + " img, [config.input_shape[0], config.input_shape[1]], method=\"nearest\"\n", + " )\n", + " return img\n", + "\n", + "\n", + "def create_tf_dataset(image_dir):\n", + " data_dir = pathlib.Path(image_dir)\n", + "\n", + " # create tf.data dataset using directory of images\n", + " predict_ds = tf.data.Dataset.list_files(str(data_dir / \"*.jpg\"), shuffle=False)\n", + "\n", + " # use map to convert string paths to uint8 image tensors\n", + " # setting `num_parallel_calls' helps in processing multiple images parallely\n", + " predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)\n", + "\n", + " # create a Prefetch Dataset for better latency & throughput\n", + " predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)\n", + " return predict_ds\n", + "\n", + "\n", + 
"def predict(predict_ds):\n", + " # ShiftViT model returns logits (non-normalized predictions)\n", + " logits = saved_model.predict(predict_ds)\n", + "\n", + " # normalize predictions by calling softmax()\n", + " probabilities = tf.nn.softmax(logits)\n", + " return probabilities\n", + "\n", + "\n", + "def get_predicted_class(probabilities):\n", + " pred_label = np.argmax(probabilities)\n", + " predicted_class = config.label_map[pred_label]\n", + " return predicted_class\n", + "\n", + "\n", + "def get_confidence_scores(probabilities):\n", + " # get the indices of the probability scores sorted in descending order\n", + " labels = np.argsort(probabilities)[::-1]\n", + " confidences = {\n", + " config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n", + " for label in labels\n", + " }\n", + " return confidences\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Get predictions**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "img_dir = \"inference_set\"\n", + "predict_ds = create_tf_dataset(img_dir)\n", + "probabilities = predict(predict_ds)\n", + "print(f\"probabilities: {probabilities[0]}\")\n", + "confidences = get_confidence_scores(probabilities[0])\n", + "print(confidences)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**View predictions**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(10, 10))\n", + "for images in predict_ds:\n", + " for i in range(min(6, probabilities.shape[0])):\n", + " ax = plt.subplot(3, 3, i + 1)\n", + " plt.imshow(images[i].numpy().astype(\"uint8\"))\n", + " predicted_class = get_predicted_class(probabilities[i])\n", + " plt.title(predicted_class)\n", + " plt.axis(\"off\")" + ] + }, { "cell_type": "markdown", "metadata": { @@ -1069,6 +1354,19 @@ "- A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for\n", "helping us with the Learning Rate Schedule." ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "**Example available on HuggingFace**\n", + "\n", + "| Trained Model | Demo |\n", + "| :--: | :--: |\n", + "| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |" + ] } ], "metadata": { @@ -1100,4 +1398,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/examples/vision/md/shiftvit.md b/examples/vision/md/shiftvit.md index 6738bfb777..0b583bd1c2 100644 --- a/examples/vision/md/shiftvit.md +++ b/examples/vision/md/shiftvit.md @@ -1,8 +1,8 @@ # A Vision Transformer without Attention -**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha)
+**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)
**Date created:** 2022/02/24
-**Last modified:** 2022/03/01
+**Last modified:** 2022/10/15
**Description:** A minimal implementation of ShiftViT. @@ -30,13 +30,15 @@ operation with a shifting operation. In this example, we minimally implement the paper with close alignement to the author's [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py). -This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can +This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed using the following command: -```shell -pip install -qq -U tensorflow-addons + +```python +!pip install -qq -U tensorflow-addons ``` + --- ## Setup and imports @@ -48,9 +50,11 @@ import matplotlib.pyplot as plt import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers - import tensorflow_addons as tfa +import pathlib +import glob + # Setting seed for reproducibiltiy SEED = 42 keras.utils.set_random_seed(SEED) @@ -94,6 +98,21 @@ class Config(object): # TRAINING epochs = 100 + # INFERENCE + label_map = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } + tf_ds_batch_size = 20 + config = Config() ``` @@ -127,14 +146,12 @@ test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
``` +Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz +170498071/170498071 [==============================] - 3s 0us/step Training samples: 40000 Validation samples: 10000 Testing samples: 10000 -2022-03-01 03:10:21.342684: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA -To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags. -2022-03-01 03:10:21.850844: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory: -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:61:00.0, compute capability: 8.0 - ```
--- @@ -219,7 +236,7 @@ The Shift Block as shown in Fig. 3, comprises of the following: #### The MLP block -The MLP block is intended to be a stack of densely-connected layers.s +The MLP block is intended to be a stack of densely-connected layers ```python @@ -243,7 +260,10 @@ class MLP(layers.Layer): self.mlp = keras.Sequential( [ - layers.Dense(units=initial_filters, activation=tf.nn.gelu,), + layers.Dense( + units=initial_filters, + activation=tf.nn.gelu, + ), layers.Dropout(rate=self.mlp_dropout_rate), layers.Dense(units=input_channels), layers.Dropout(rate=self.mlp_dropout_rate), @@ -291,7 +311,7 @@ class DropPath(layers.Layer): #### Block -The most important operation in this paper is the **shift opperation**. In this section, +The most important operation in this paper is the **shift operation**. In this section, we describe the shift operation and compare it with its original implementation provided by the authors. @@ -563,6 +583,24 @@ class StackedShiftBlocks(layers.Layer): x = self.patch_merge(x) return x + # Since this is a custom layer, we need to overwrite get_config() + # so that model can be easily saved & loaded after training + def get_config(self): + config = super().get_config() + config.update( + { + "epsilon": self.epsilon, + "mlp_dropout_rate": self.mlp_dropout_rate, + "num_shift_blocks": self.num_shift_blocks, + "stochastic_depth_rate": self.stochastic_depth_rate, + "is_merge": self.is_merge, + "num_div": self.num_div, + "shift_pixel": self.shift_pixel, + "mlp_expand_ratio": self.mlp_expand_ratio, + } + ) + return config + ``` --- @@ -637,6 +675,8 @@ class ShiftViTModel(keras.Model): ) self.global_avg_pool = layers.GlobalAveragePooling2D() + self.classifier = layers.Dense(config.num_classes) + def get_config(self): config = super().get_config() config.update( @@ -645,6 +685,7 @@ class ShiftViTModel(keras.Model): "patch_projection": self.patch_projection, "stages": self.stages, "global_avg_pool": self.global_avg_pool, + "classifier": self.classifier, } ) return config @@ -664,7 +705,8 @@ class ShiftViTModel(keras.Model): x = stage(x, training=training) # Get the logits. - logits = self.global_avg_pool(x) + x = self.global_avg_pool(x) + logits = self.classifier(x) # Calculate the loss and return it. 
total_loss = self.compiled_loss(labels, logits) @@ -681,6 +723,7 @@ class ShiftViTModel(keras.Model): self.data_augmentation.trainable_variables, self.patch_projection.trainable_variables, self.global_avg_pool.trainable_variables, + self.classifier.trainable_variables, ] train_vars = train_vars + [stage.trainable_variables for stage in self.stages] @@ -703,6 +746,15 @@ class ShiftViTModel(keras.Model): self.compiled_metrics.update_state(labels, logits) return {m.name: m.result() for m in self.metrics} + def call(self, images): + augmented_images = self.data_augmentation(images) + x = self.patch_projection(augmented_images) + for stage in self.stages: + x = stage(x, training=False) + x = self.global_avg_pool(x) + logits = self.classifier(x) + return logits + ``` --- @@ -810,6 +862,15 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): step > self.total_steps, 0.0, learning_rate, name="learning_rate" ) + def get_config(self): + config = { + "lr_start": self.lr_start, + "lr_max": self.lr_max, + "total_steps": self.total_steps, + "warmup_steps": self.warmup_steps, + } + return config + ``` --- @@ -817,6 +878,11 @@ class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule): ```python +# pass sample data to the model so that input shape is available at the time of +# saving the model +sample_ds, _ = next(iter(train_ds)) +model(sample_ds, training=False) + # Get the total number of steps for training. total_steps = int((len(x_train) / config.batch_size) * config.epochs) @@ -826,7 +892,10 @@ warmup_steps = int(total_steps * warmup_epoch_percentage) # Initialize the warmupcosine schedule. scheduled_lrs = WarmUpCosine( - lr_start=1e-5, lr_max=1e-3, warmup_steps=warmup_steps, total_steps=total_steps, + lr_start=1e-5, + lr_max=1e-3, + warmup_steps=warmup_steps, + total_steps=total_steps, ) # Get the optimizer. @@ -850,7 +919,11 @@ history = model.fit( epochs=config.epochs, validation_data=val_ds, callbacks=[ - keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, mode="auto",) + keras.callbacks.EarlyStopping( + monitor="val_accuracy", + patience=5, + mode="auto", + ) ], ) @@ -865,109 +938,249 @@ print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
``` Epoch 1/100 - -2022-03-01 03:10:41.373231: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8202 -2022-03-01 03:10:43.145958: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. - -157/157 [==============================] - 34s 84ms/step - loss: 3.2975 - accuracy: 0.1084 - top-5-accuracy: 0.4806 - val_loss: 2.1575 - val_accuracy: 0.2017 - val_top-5-accuracy: 0.7184 +157/157 [==============================] - 72s 332ms/step - loss: 2.3844 - accuracy: 0.1444 - top-5-accuracy: 0.6051 - val_loss: 2.0984 - val_accuracy: 0.2610 - val_top-5-accuracy: 0.7638 Epoch 2/100 -157/157 [==============================] - 11s 67ms/step - loss: 2.1727 - accuracy: 0.2289 - top-5-accuracy: 0.7516 - val_loss: 1.8819 - val_accuracy: 0.3182 - val_top-5-accuracy: 0.8386 +157/157 [==============================] - 49s 314ms/step - loss: 1.9457 - accuracy: 0.2893 - top-5-accuracy: 0.8103 - val_loss: 1.9459 - val_accuracy: 0.3356 - val_top-5-accuracy: 0.8614 Epoch 3/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.8169 - accuracy: 0.3426 - top-5-accuracy: 0.8592 - val_loss: 1.6174 - val_accuracy: 0.4053 - val_top-5-accuracy: 0.8934 +157/157 [==============================] - 50s 316ms/step - loss: 1.7093 - accuracy: 0.3810 - top-5-accuracy: 0.8761 - val_loss: 1.5349 - val_accuracy: 0.4585 - val_top-5-accuracy: 0.9045 Epoch 4/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.6215 - accuracy: 0.4092 - top-5-accuracy: 0.8983 - val_loss: 1.4239 - val_accuracy: 0.4903 - val_top-5-accuracy: 0.9216 +157/157 [==============================] - 49s 315ms/step - loss: 1.5473 - accuracy: 0.4374 - top-5-accuracy: 0.9090 - val_loss: 1.4257 - val_accuracy: 0.4862 - val_top-5-accuracy: 0.9298 Epoch 5/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.5081 - accuracy: 0.4571 - top-5-accuracy: 0.9148 - val_loss: 1.3359 - val_accuracy: 0.5161 - val_top-5-accuracy: 0.9369 +157/157 [==============================] - 50s 316ms/step - loss: 1.4316 - accuracy: 0.4816 - top-5-accuracy: 0.9243 - val_loss: 1.4032 - val_accuracy: 0.5092 - val_top-5-accuracy: 0.9362 Epoch 6/100 -157/157 [==============================] - 11s 68ms/step - loss: 1.4282 - accuracy: 0.4868 - top-5-accuracy: 0.9249 - val_loss: 1.2929 - val_accuracy: 0.5347 - val_top-5-accuracy: 0.9404 +157/157 [==============================] - 50s 316ms/step - loss: 1.3588 - accuracy: 0.5131 - top-5-accuracy: 0.9333 - val_loss: 1.2893 - val_accuracy: 0.5411 - val_top-5-accuracy: 0.9457 Epoch 7/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.3465 - accuracy: 0.5181 - top-5-accuracy: 0.9362 - val_loss: 1.2653 - val_accuracy: 0.5497 - val_top-5-accuracy: 0.9449 +157/157 [==============================] - 50s 316ms/step - loss: 1.2894 - accuracy: 0.5385 - top-5-accuracy: 0.9410 - val_loss: 1.2922 - val_accuracy: 0.5416 - val_top-5-accuracy: 0.9432 Epoch 8/100 -157/157 [==============================] - 10s 67ms/step - loss: 1.2907 - accuracy: 0.5400 - top-5-accuracy: 0.9416 - val_loss: 1.1919 - val_accuracy: 0.5753 - val_top-5-accuracy: 0.9515 +157/157 [==============================] - 49s 315ms/step - loss: 1.2388 - accuracy: 0.5568 - top-5-accuracy: 0.9468 - val_loss: 1.2100 - val_accuracy: 0.5733 - val_top-5-accuracy: 0.9545 Epoch 9/100 -157/157 [==============================] - 11s 67ms/step - loss: 1.2247 - accuracy: 0.5644 - top-5-accuracy: 0.9480 - val_loss: 1.1741 - 
val_accuracy: 0.5742 - val_top-5-accuracy: 0.9563 +157/157 [==============================] - 49s 315ms/step - loss: 1.2043 - accuracy: 0.5698 - top-5-accuracy: 0.9491 - val_loss: 1.2166 - val_accuracy: 0.5675 - val_top-5-accuracy: 0.9520 Epoch 10/100 -157/157 [==============================] - 11s 67ms/step - loss: 1.1983 - accuracy: 0.5760 - top-5-accuracy: 0.9505 - val_loss: 1.4545 - val_accuracy: 0.4804 - val_top-5-accuracy: 0.9198 +157/157 [==============================] - 49s 315ms/step - loss: 1.1694 - accuracy: 0.5861 - top-5-accuracy: 0.9528 - val_loss: 1.1738 - val_accuracy: 0.5883 - val_top-5-accuracy: 0.9541 Epoch 11/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.2002 - accuracy: 0.5766 - top-5-accuracy: 0.9510 - val_loss: 1.1129 - val_accuracy: 0.6055 - val_top-5-accuracy: 0.9593 +157/157 [==============================] - 50s 316ms/step - loss: 1.1290 - accuracy: 0.5994 - top-5-accuracy: 0.9575 - val_loss: 1.1161 - val_accuracy: 0.6064 - val_top-5-accuracy: 0.9618 Epoch 12/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.1309 - accuracy: 0.5990 - top-5-accuracy: 0.9575 - val_loss: 1.0369 - val_accuracy: 0.6341 - val_top-5-accuracy: 0.9638 +157/157 [==============================] - 50s 316ms/step - loss: 1.0861 - accuracy: 0.6157 - top-5-accuracy: 0.9602 - val_loss: 1.1220 - val_accuracy: 0.6133 - val_top-5-accuracy: 0.9576 Epoch 13/100 -157/157 [==============================] - 10s 66ms/step - loss: 1.0786 - accuracy: 0.6204 - top-5-accuracy: 0.9613 - val_loss: 1.0802 - val_accuracy: 0.6193 - val_top-5-accuracy: 0.9594 +157/157 [==============================] - 49s 315ms/step - loss: 1.0766 - accuracy: 0.6178 - top-5-accuracy: 0.9612 - val_loss: 1.0108 - val_accuracy: 0.6402 - val_top-5-accuracy: 0.9681 Epoch 14/100 -157/157 [==============================] - 10s 65ms/step - loss: 1.0438 - accuracy: 0.6330 - top-5-accuracy: 0.9640 - val_loss: 0.9584 - val_accuracy: 0.6596 - val_top-5-accuracy: 0.9713 +157/157 [==============================] - 49s 315ms/step - loss: 1.0179 - accuracy: 0.6416 - top-5-accuracy: 0.9658 - val_loss: 1.0196 - val_accuracy: 0.6405 - val_top-5-accuracy: 0.9667 Epoch 15/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9957 - accuracy: 0.6496 - top-5-accuracy: 0.9684 - val_loss: 0.9530 - val_accuracy: 0.6636 - val_top-5-accuracy: 0.9712 +157/157 [==============================] - 50s 316ms/step - loss: 1.0028 - accuracy: 0.6470 - top-5-accuracy: 0.9678 - val_loss: 1.0113 - val_accuracy: 0.6415 - val_top-5-accuracy: 0.9672 Epoch 16/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9710 - accuracy: 0.6599 - top-5-accuracy: 0.9696 - val_loss: 0.8856 - val_accuracy: 0.6863 - val_top-5-accuracy: 0.9756 +157/157 [==============================] - 50s 316ms/step - loss: 0.9613 - accuracy: 0.6611 - top-5-accuracy: 0.9710 - val_loss: 1.0516 - val_accuracy: 0.6406 - val_top-5-accuracy: 0.9596 Epoch 17/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.9316 - accuracy: 0.6706 - top-5-accuracy: 0.9721 - val_loss: 0.9919 - val_accuracy: 0.6480 - val_top-5-accuracy: 0.9671 +157/157 [==============================] - 50s 316ms/step - loss: 0.9262 - accuracy: 0.6740 - top-5-accuracy: 0.9729 - val_loss: 0.9010 - val_accuracy: 0.6844 - val_top-5-accuracy: 0.9750 Epoch 18/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8899 - accuracy: 0.6884 - top-5-accuracy: 0.9763 - val_loss: 0.8753 - val_accuracy: 0.6949 - val_top-5-accuracy: 
0.9752 +157/157 [==============================] - 50s 316ms/step - loss: 0.8768 - accuracy: 0.6916 - top-5-accuracy: 0.9769 - val_loss: 0.8862 - val_accuracy: 0.6908 - val_top-5-accuracy: 0.9767 Epoch 19/100 -157/157 [==============================] - 10s 64ms/step - loss: 0.8529 - accuracy: 0.6979 - top-5-accuracy: 0.9772 - val_loss: 0.8793 - val_accuracy: 0.6943 - val_top-5-accuracy: 0.9754 +157/157 [==============================] - 49s 315ms/step - loss: 0.8595 - accuracy: 0.6984 - top-5-accuracy: 0.9768 - val_loss: 0.8732 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9738 Epoch 20/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8509 - accuracy: 0.7009 - top-5-accuracy: 0.9783 - val_loss: 0.8183 - val_accuracy: 0.7174 - val_top-5-accuracy: 0.9763 +157/157 [==============================] - 50s 317ms/step - loss: 0.8252 - accuracy: 0.7103 - top-5-accuracy: 0.9793 - val_loss: 0.9330 - val_accuracy: 0.6745 - val_top-5-accuracy: 0.9718 Epoch 21/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8087 - accuracy: 0.7143 - top-5-accuracy: 0.9809 - val_loss: 0.7885 - val_accuracy: 0.7276 - val_top-5-accuracy: 0.9769 +157/157 [==============================] - 51s 322ms/step - loss: 0.8003 - accuracy: 0.7180 - top-5-accuracy: 0.9814 - val_loss: 0.8912 - val_accuracy: 0.6948 - val_top-5-accuracy: 0.9728 Epoch 22/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.8004 - accuracy: 0.7192 - top-5-accuracy: 0.9811 - val_loss: 0.7601 - val_accuracy: 0.7371 - val_top-5-accuracy: 0.9805 +157/157 [==============================] - 51s 326ms/step - loss: 0.7651 - accuracy: 0.7317 - top-5-accuracy: 0.9829 - val_loss: 0.7894 - val_accuracy: 0.7277 - val_top-5-accuracy: 0.9791 Epoch 23/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7665 - accuracy: 0.7304 - top-5-accuracy: 0.9816 - val_loss: 0.7564 - val_accuracy: 0.7412 - val_top-5-accuracy: 0.9808 +157/157 [==============================] - 52s 328ms/step - loss: 0.7372 - accuracy: 0.7415 - top-5-accuracy: 0.9843 - val_loss: 0.7752 - val_accuracy: 0.7284 - val_top-5-accuracy: 0.9804 Epoch 24/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7599 - accuracy: 0.7344 - top-5-accuracy: 0.9832 - val_loss: 0.7475 - val_accuracy: 0.7389 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 51s 327ms/step - loss: 0.7324 - accuracy: 0.7423 - top-5-accuracy: 0.9852 - val_loss: 0.7949 - val_accuracy: 0.7340 - val_top-5-accuracy: 0.9792 Epoch 25/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7398 - accuracy: 0.7427 - top-5-accuracy: 0.9833 - val_loss: 0.7211 - val_accuracy: 0.7504 - val_top-5-accuracy: 0.9829 +157/157 [==============================] - 51s 323ms/step - loss: 0.7051 - accuracy: 0.7512 - top-5-accuracy: 0.9858 - val_loss: 0.7967 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9787 Epoch 26/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.7114 - accuracy: 0.7500 - top-5-accuracy: 0.9857 - val_loss: 0.7385 - val_accuracy: 0.7462 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 51s 323ms/step - loss: 0.6832 - accuracy: 0.7577 - top-5-accuracy: 0.9870 - val_loss: 0.7840 - val_accuracy: 0.7322 - val_top-5-accuracy: 0.9807 Epoch 27/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6954 - accuracy: 0.7577 - top-5-accuracy: 0.9851 - val_loss: 0.7477 - val_accuracy: 0.7402 - val_top-5-accuracy: 0.9802 +157/157 
[==============================] - 51s 322ms/step - loss: 0.6609 - accuracy: 0.7654 - top-5-accuracy: 0.9877 - val_loss: 0.7447 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9816 Epoch 28/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6807 - accuracy: 0.7588 - top-5-accuracy: 0.9871 - val_loss: 0.7275 - val_accuracy: 0.7536 - val_top-5-accuracy: 0.9822 +157/157 [==============================] - 50s 319ms/step - loss: 0.6495 - accuracy: 0.7724 - top-5-accuracy: 0.9883 - val_loss: 0.7885 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9817 Epoch 29/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6719 - accuracy: 0.7648 - top-5-accuracy: 0.9876 - val_loss: 0.7261 - val_accuracy: 0.7487 - val_top-5-accuracy: 0.9815 +157/157 [==============================] - 50s 317ms/step - loss: 0.6491 - accuracy: 0.7707 - top-5-accuracy: 0.9885 - val_loss: 0.7539 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9821 Epoch 30/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.6578 - accuracy: 0.7696 - top-5-accuracy: 0.9871 - val_loss: 0.6932 - val_accuracy: 0.7641 - val_top-5-accuracy: 0.9833 +157/157 [==============================] - 50s 317ms/step - loss: 0.6213 - accuracy: 0.7823 - top-5-accuracy: 0.9888 - val_loss: 0.7571 - val_accuracy: 0.7470 - val_top-5-accuracy: 0.9815 Epoch 31/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.6489 - accuracy: 0.7740 - top-5-accuracy: 0.9877 - val_loss: 0.7400 - val_accuracy: 0.7486 - val_top-5-accuracy: 0.9820 +157/157 [==============================] - 50s 318ms/step - loss: 0.5976 - accuracy: 0.7902 - top-5-accuracy: 0.9906 - val_loss: 0.7430 - val_accuracy: 0.7508 - val_top-5-accuracy: 0.9817 Epoch 32/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.6290 - accuracy: 0.7812 - top-5-accuracy: 0.9895 - val_loss: 0.6954 - val_accuracy: 0.7628 - val_top-5-accuracy: 0.9847 +157/157 [==============================] - 50s 318ms/step - loss: 0.5932 - accuracy: 0.7898 - top-5-accuracy: 0.9910 - val_loss: 0.7545 - val_accuracy: 0.7469 - val_top-5-accuracy: 0.9793 Epoch 33/100 -157/157 [==============================] - 10s 67ms/step - loss: 0.6194 - accuracy: 0.7826 - top-5-accuracy: 0.9894 - val_loss: 0.6913 - val_accuracy: 0.7619 - val_top-5-accuracy: 0.9842 +157/157 [==============================] - 50s 318ms/step - loss: 0.5977 - accuracy: 0.7850 - top-5-accuracy: 0.9913 - val_loss: 0.7200 - val_accuracy: 0.7569 - val_top-5-accuracy: 0.9830 Epoch 34/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.5917 - accuracy: 0.7930 - top-5-accuracy: 0.9902 - val_loss: 0.6879 - val_accuracy: 0.7715 - val_top-5-accuracy: 0.9831 +157/157 [==============================] - 50s 317ms/step - loss: 0.5552 - accuracy: 0.8041 - top-5-accuracy: 0.9920 - val_loss: 0.7377 - val_accuracy: 0.7552 - val_top-5-accuracy: 0.9818 Epoch 35/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5878 - accuracy: 0.7916 - top-5-accuracy: 0.9907 - val_loss: 0.6759 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9849 +157/157 [==============================] - 50s 319ms/step - loss: 0.5509 - accuracy: 0.8056 - top-5-accuracy: 0.9921 - val_loss: 0.8125 - val_accuracy: 0.7331 - val_top-5-accuracy: 0.9782 Epoch 36/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5713 - accuracy: 0.8004 - top-5-accuracy: 0.9913 - val_loss: 0.6920 - val_accuracy: 0.7657 - val_top-5-accuracy: 0.9841 +157/157 [==============================] - 50s 
317ms/step - loss: 0.5296 - accuracy: 0.8116 - top-5-accuracy: 0.9933 - val_loss: 0.6900 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9849 Epoch 37/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5590 - accuracy: 0.8040 - top-5-accuracy: 0.9913 - val_loss: 0.6790 - val_accuracy: 0.7718 - val_top-5-accuracy: 0.9831 +157/157 [==============================] - 50s 316ms/step - loss: 0.5151 - accuracy: 0.8170 - top-5-accuracy: 0.9941 - val_loss: 0.7275 - val_accuracy: 0.7610 - val_top-5-accuracy: 0.9841 Epoch 38/100 -157/157 [==============================] - 11s 67ms/step - loss: 0.5445 - accuracy: 0.8114 - top-5-accuracy: 0.9926 - val_loss: 0.6756 - val_accuracy: 0.7720 - val_top-5-accuracy: 0.9852 +157/157 [==============================] - 50s 317ms/step - loss: 0.5069 - accuracy: 0.8217 - top-5-accuracy: 0.9936 - val_loss: 0.7067 - val_accuracy: 0.7703 - val_top-5-accuracy: 0.9835 Epoch 39/100 -157/157 [==============================] - 11s 67ms/step - loss: 0.5292 - accuracy: 0.8155 - top-5-accuracy: 0.9930 - val_loss: 0.6578 - val_accuracy: 0.7807 - val_top-5-accuracy: 0.9845 +157/157 [==============================] - 50s 318ms/step - loss: 0.4771 - accuracy: 0.8304 - top-5-accuracy: 0.9945 - val_loss: 0.7110 - val_accuracy: 0.7668 - val_top-5-accuracy: 0.9836 Epoch 40/100 -157/157 [==============================] - 11s 68ms/step - loss: 0.5169 - accuracy: 0.8181 - top-5-accuracy: 0.9926 - val_loss: 0.6582 - val_accuracy: 0.7795 - val_top-5-accuracy: 0.9849 +157/157 [==============================] - 50s 317ms/step - loss: 0.4675 - accuracy: 0.8350 - top-5-accuracy: 0.9956 - val_loss: 0.7130 - val_accuracy: 0.7688 - val_top-5-accuracy: 0.9829 Epoch 41/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.5108 - accuracy: 0.8217 - top-5-accuracy: 0.9937 - val_loss: 0.6344 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9855 +157/157 [==============================] - 50s 319ms/step - loss: 0.4586 - accuracy: 0.8382 - top-5-accuracy: 0.9959 - val_loss: 0.7331 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9806 Epoch 42/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.5056 - accuracy: 0.8220 - top-5-accuracy: 0.9936 - val_loss: 0.6723 - val_accuracy: 0.7744 - val_top-5-accuracy: 0.9851 +157/157 [==============================] - 50s 318ms/step - loss: 0.4558 - accuracy: 0.8380 - top-5-accuracy: 0.9959 - val_loss: 0.7187 - val_accuracy: 0.7722 - val_top-5-accuracy: 0.9832 Epoch 43/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.4824 - accuracy: 0.8317 - top-5-accuracy: 0.9943 - val_loss: 0.6800 - val_accuracy: 0.7771 - val_top-5-accuracy: 0.9834 +157/157 [==============================] - 50s 320ms/step - loss: 0.4356 - accuracy: 0.8450 - top-5-accuracy: 0.9958 - val_loss: 0.7162 - val_accuracy: 0.7693 - val_top-5-accuracy: 0.9850 Epoch 44/100 -157/157 [==============================] - 10s 67ms/step - loss: 0.4719 - accuracy: 0.8339 - top-5-accuracy: 0.9938 - val_loss: 0.6742 - val_accuracy: 0.7785 - val_top-5-accuracy: 0.9840 +157/157 [==============================] - 49s 314ms/step - loss: 0.4425 - accuracy: 0.8433 - top-5-accuracy: 0.9958 - val_loss: 0.7061 - val_accuracy: 0.7698 - val_top-5-accuracy: 0.9853 Epoch 45/100 -157/157 [==============================] - 10s 65ms/step - loss: 0.4605 - accuracy: 0.8379 - top-5-accuracy: 0.9953 - val_loss: 0.6732 - val_accuracy: 0.7781 - val_top-5-accuracy: 0.9841 +157/157 [==============================] - 49s 314ms/step - loss: 0.4072 - accuracy: 0.8551 
- top-5-accuracy: 0.9967 - val_loss: 0.7025 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9848 Epoch 46/100 -157/157 [==============================] - 10s 66ms/step - loss: 0.4608 - accuracy: 0.8390 - top-5-accuracy: 0.9947 - val_loss: 0.6547 - val_accuracy: 0.7846 - val_top-5-accuracy: 0.9852 +157/157 [==============================] - 49s 314ms/step - loss: 0.3865 - accuracy: 0.8644 - top-5-accuracy: 0.9970 - val_loss: 0.7178 - val_accuracy: 0.7740 - val_top-5-accuracy: 0.9844 +Epoch 47/100 +157/157 [==============================] - 49s 313ms/step - loss: 0.3718 - accuracy: 0.8694 - top-5-accuracy: 0.9973 - val_loss: 0.7216 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9828 +Epoch 48/100 +157/157 [==============================] - 49s 314ms/step - loss: 0.3733 - accuracy: 0.8673 - top-5-accuracy: 0.9970 - val_loss: 0.7440 - val_accuracy: 0.7713 - val_top-5-accuracy: 0.9841 +Epoch 49/100 +157/157 [==============================] - 49s 313ms/step - loss: 0.3531 - accuracy: 0.8741 - top-5-accuracy: 0.9979 - val_loss: 0.7220 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9848 +Epoch 50/100 +157/157 [==============================] - 49s 314ms/step - loss: 0.3502 - accuracy: 0.8738 - top-5-accuracy: 0.9980 - val_loss: 0.7245 - val_accuracy: 0.7734 - val_top-5-accuracy: 0.9836 TESTING -40/40 [==============================] - 1s 22ms/step - loss: 0.6801 - accuracy: 0.7720 - top-5-accuracy: 0.9864 -Loss: 0.68 -Top 1 test accuracy: 77.20% -Top 5 test accuracy: 98.64% +40/40 [==============================] - 2s 56ms/step - loss: 0.7336 - accuracy: 0.7638 - top-5-accuracy: 0.9855 +Loss: 0.73 +Top 1 test accuracy: 76.38% +Top 5 test accuracy: 98.55% ```
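To see how training progressed at a glance, the curves can also be plotted rather than read off the logs. The snippet below is only a sketch: it assumes the `History` object returned by `model.fit()` is still available under the name `history` (not shown in the log output above).

```python
# Sketch: plot top-1 accuracy curves, assuming `history` is the object
# returned by `model.fit()` for this run.
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
plt.plot(history.history["accuracy"], label="train top-1 accuracy")
plt.plot(history.history["val_accuracy"], label="val top-1 accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
```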
+---
+## Save trained model
+
+Since we created the model by subclassing, we can't save it in the HDF5 format.
+
+It can only be saved in the TF SavedModel format, which is also the generally recommended format for saving models.
+
+
+```python
+model.save("ShiftViT")
+```
+
+---
+## Model inference
+
+**Download sample data for inference**
+
+
+```python
+!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+!unzip -q inference_set.zip
+```
+
+**Load saved model**
+
+
+```python
+# Custom objects are not included when the model is saved.
+# At loading time, these objects need to be passed for reconstruction of the model
+saved_model = tf.keras.models.load_model(
+    "ShiftViT",
+    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+```
+
+
+**Utility functions for inference**
+
+
+```python
+
+def process_image(img_path):
+    # read image file from string path
+    img = tf.io.read_file(img_path)
+
+    # decode jpeg to uint8 tensor
+    img = tf.io.decode_jpeg(img, channels=3)
+
+    # resize image to match input size accepted by model
+    # use `method` as `nearest` to preserve dtype of input passed to `resize()`
+    img = tf.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    )
+    return img
+
+
+def create_tf_dataset(image_dir):
+    data_dir = pathlib.Path(image_dir)
+
+    # create tf.data dataset using directory of images
+    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+    # use map to convert string paths to uint8 image tensors
+    # setting `num_parallel_calls` helps process multiple images in parallel
+    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+    # create a Prefetch Dataset for better latency & throughput
+    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+    return predict_ds
+
+
+def predict(predict_ds):
+    # ShiftViT model returns logits (non-normalized predictions)
+    logits = saved_model.predict(predict_ds)
+
+    # normalize predictions by calling softmax()
+    probabilities = tf.nn.softmax(logits)
+    return probabilities
+
+
+def get_predicted_class(probabilities):
+    pred_label = np.argmax(probabilities)
+    predicted_class = config.label_map[pred_label]
+    return predicted_class
+
+
+def get_confidence_scores(probabilities):
+    # get the indices of the probability scores sorted in descending order
+    labels = np.argsort(probabilities)[::-1]
+    confidences = {
+        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+        for label in labels
+    }
+    return confidences
+
+```
+
+**Get predictions**
+
+
+```python
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+```
+
+
+``` +1/1 [==============================] - 2s 2s/step +probabilities: [8.7329084e-01 1.3162658e-03 6.1781306e-05 1.9132349e-05 4.4482469e-05 + 1.8182898e-06 2.2834571e-05 1.1466043e-05 1.2504059e-01 1.9084632e-04] +{'airplane': 87.33, 'ship': 12.5, 'automobile': 0.13, 'truck': 0.02, 'bird': 0.01, 'deer': 0.0, 'frog': 0.0, 'cat': 0.0, 'horse': 0.0, 'dog': 0.0} + +``` +
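The printed dictionary already lists every class sorted by confidence. If only the few most likely classes are needed, `tf.math.top_k` can pull them straight from the probability vector. A small sketch, reusing the `probabilities` tensor and `config.label_map` defined above:

```python
# Top-3 classes for the first image, as a sanity check on the printed scores.
top_k = tf.math.top_k(probabilities[0], k=3)
for score, idx in zip(top_k.values.numpy(), top_k.indices.numpy()):
    print(f"{config.label_map[int(idx)]}: {score * 100:.2f}%")
```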
+**View predictions** + + +```python +plt.figure(figsize=(10, 10)) +for images in predict_ds: + for i in range(min(6, probabilities.shape[0])): + ax = plt.subplot(3, 3, i + 1) + plt.imshow(images[i].numpy().astype("uint8")) + predicted_class = get_predicted_class(probabilities[i]) + plt.title(predicted_class) + plt.axis("off") +``` + + +![png](/img/examples/vision/shiftvit/shiftvit_46_0.png) + + --- ## Conclusion @@ -990,3 +1203,9 @@ GPU credits. library. - A personal note of thanks to [Puja Roychowdhury](https://twitter.com/pleb_talks) for helping us with the Learning Rate Schedule. + +**Example available on HuggingFace** + +| Trained Model | Demo | +| :--: | :--: | +| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) | diff --git a/examples/vision/shiftvit.py b/examples/vision/shiftvit.py index 3d3476d90c..e6400449c8 100644 --- a/examples/vision/shiftvit.py +++ b/examples/vision/shiftvit.py @@ -1,8 +1,8 @@ """ Title: A Vision Transformer without Attention -Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha) +Author: [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/) Date created: 2022/02/24 -Last modified: 2022/03/01 +Last modified: 2022/10/15 Description: A minimal implementation of ShiftViT. Accelerator: GPU """ @@ -26,12 +26,11 @@ In this example, we minimally implement the paper with close alignement to the author's [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py). -This example requires TensorFlow 2.6 or higher, as well as TensorFlow Addons, which can +This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed using the following command: - -```shell +""" +"""shell pip install -qq -U tensorflow-addons -``` """ """ @@ -44,9 +43,11 @@ import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers - import tensorflow_addons as tfa +import pathlib +import glob + # Setting seed for reproducibiltiy SEED = 42 keras.utils.set_random_seed(SEED) @@ -88,6 +89,21 @@ class Config(object): # TRAINING epochs = 100 + # INFERENCE + label_map = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck", + } + tf_ds_batch_size = 20 + config = Config() @@ -200,7 +216,7 @@ def get_augmentation_model(): """ #### The MLP block -The MLP block is intended to be a stack of densely-connected layers.s +The MLP block is intended to be a stack of densely-connected layers """ @@ -273,7 +289,7 @@ def call(self, x, training=False): """ #### Block -The most important operation in this paper is the **shift opperation**. In this section, +The most important operation in this paper is the **shift operation**. In this section, we describe the shift operation and compare it with its original implementation provided by the authors. 
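Before the next set of changes, it may help to see the shift operation in isolation. The sketch below is a rough, self-contained illustration of the zero-padded channel shift described above, not the exact helper used in the example: a few channel groups are shifted by `shift_pixel` along each of the four spatial directions, the vacated positions are zero-filled, and the remaining channels pass through unchanged.

```python
import tensorflow as tf


def shift_channels(x, num_div=12, shift_pixel=1):
    """Illustrative zero-padded spatial shift of channel groups.

    x: a (batch, height, width, channels) feature map.
    The first four groups of size channels // num_div are shifted
    left, right, up and down respectively; the rest stay untouched.
    """
    div = x.shape[-1] // num_div
    s = shift_pixel

    def pad(t, hw_paddings):
        # Pad only the spatial axes; batch and channel axes are left alone.
        return tf.pad(t, [[0, 0]] + hw_paddings + [[0, 0]])

    left = pad(x[:, :, s:, 0 * div : 1 * div], [[0, 0], [0, s]])
    right = pad(x[:, :, :-s, 1 * div : 2 * div], [[0, 0], [s, 0]])
    up = pad(x[:, s:, :, 2 * div : 3 * div], [[0, s], [0, 0]])
    down = pad(x[:, :-s, :, 3 * div : 4 * div], [[s, 0], [0, 0]])
    rest = x[..., 4 * div :]
    return tf.concat([left, right, up, down, rest], axis=-1)


# Tiny smoke test on a random feature map: the shape is preserved.
features = tf.random.normal((1, 8, 8, 48))
print(shift_channels(features).shape)  # (1, 8, 8, 48)
```

With `num_div=12`, each direction moves 1/12 of the channels, which roughly mirrors the default ratio used in the paper.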
@@ -545,6 +561,24 @@ def call(self, x, training=False):
         x = self.patch_merge(x)
         return x
 
+    # Since this is a custom layer, we need to override get_config()
+    # so that the model can be easily saved & loaded after training
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "epsilon": self.epsilon,
+                "mlp_dropout_rate": self.mlp_dropout_rate,
+                "num_shift_blocks": self.num_shift_blocks,
+                "stochastic_depth_rate": self.stochastic_depth_rate,
+                "is_merge": self.is_merge,
+                "num_div": self.num_div,
+                "shift_pixel": self.shift_pixel,
+                "mlp_expand_ratio": self.mlp_expand_ratio,
+            }
+        )
+        return config
+
 
 """
 ## The ShiftViT model
@@ -617,6 +651,8 @@ def __init__(
         )
         self.global_avg_pool = layers.GlobalAveragePooling2D()
 
+        self.classifier = layers.Dense(config.num_classes)
+
     def get_config(self):
         config = super().get_config()
         config.update(
@@ -625,6 +661,7 @@ def get_config(self):
                 "patch_projection": self.patch_projection,
                 "stages": self.stages,
                 "global_avg_pool": self.global_avg_pool,
+                "classifier": self.classifier,
             }
         )
         return config
@@ -644,7 +681,8 @@ def _calculate_loss(self, data, training=False):
             x = stage(x, training=training)
 
         # Get the logits.
-        logits = self.global_avg_pool(x)
+        x = self.global_avg_pool(x)
+        logits = self.classifier(x)
 
         # Calculate the loss and return it.
         total_loss = self.compiled_loss(labels, logits)
@@ -661,6 +699,7 @@ def train_step(self, inputs):
                 self.data_augmentation.trainable_variables,
                 self.patch_projection.trainable_variables,
                 self.global_avg_pool.trainable_variables,
+                self.classifier.trainable_variables,
             ]
             train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
 
@@ -683,6 +722,15 @@ def test_step(self, data):
         self.compiled_metrics.update_state(labels, logits)
         return {m.name: m.result() for m in self.metrics}
 
+    def call(self, images):
+        augmented_images = self.data_augmentation(images)
+        x = self.patch_projection(augmented_images)
+        for stage in self.stages:
+            x = stage(x, training=False)
+        x = self.global_avg_pool(x)
+        logits = self.classifier(x)
+        return logits
+
 
 """
 ## Instantiate the model
@@ -787,11 +835,25 @@ def __call__(self, step):
             step > self.total_steps, 0.0, learning_rate, name="learning_rate"
         )
 
+    def get_config(self):
+        config = {
+            "lr_start": self.lr_start,
+            "lr_max": self.lr_max,
+            "total_steps": self.total_steps,
+            "warmup_steps": self.warmup_steps,
+        }
+        return config
+
 
 """
 ## Compile and train the model
 """
 
+# Pass sample data to the model so that the input shape is available at the time of
+# saving the model.
+sample_ds, _ = next(iter(train_ds))
+model(sample_ds, training=False)
+
 # Get the total number of steps for training.
 total_steps = int((len(x_train) / config.batch_size) * config.epochs)
 
@@ -843,6 +905,123 @@ def __call__(self, step):
 print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
 print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
 
+"""
+## Save trained model
+
+Since we created the model by subclassing, we can't save it in the HDF5 format.
+
+It can only be saved in the TF SavedModel format, which is also the generally recommended format for saving models.
+"""
+model.save("ShiftViT")
+
+"""
+## Model inference
+"""
+
+"""
+**Download sample data for inference**
+"""
+
+"""shell
+wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
+unzip -q inference_set.zip
+"""
+
+
+"""
+**Load saved model**
+"""
+# Custom objects are not included when the model is saved.
+# At loading time, these objects need to be passed for reconstruction of the model
+saved_model = tf.keras.models.load_model(
+    "ShiftViT",
+    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
+)
+
+"""
+**Utility functions for inference**
+"""
+
+
+def process_image(img_path):
+    # read image file from string path
+    img = tf.io.read_file(img_path)
+
+    # decode jpeg to uint8 tensor
+    img = tf.io.decode_jpeg(img, channels=3)
+
+    # resize image to match input size accepted by model
+    # use `method` as `nearest` to preserve dtype of input passed to `resize()`
+    img = tf.image.resize(
+        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
+    )
+    return img
+
+
+def create_tf_dataset(image_dir):
+    data_dir = pathlib.Path(image_dir)
+
+    # create tf.data dataset using directory of images
+    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
+
+    # use map to convert string paths to uint8 image tensors
+    # setting `num_parallel_calls` helps process multiple images in parallel
+    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
+
+    # create a Prefetch Dataset for better latency & throughput
+    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
+    return predict_ds
+
+
+def predict(predict_ds):
+    # ShiftViT model returns logits (non-normalized predictions)
+    logits = saved_model.predict(predict_ds)
+
+    # normalize predictions by calling softmax()
+    probabilities = tf.nn.softmax(logits)
+    return probabilities
+
+
+def get_predicted_class(probabilities):
+    pred_label = np.argmax(probabilities)
+    predicted_class = config.label_map[pred_label]
+    return predicted_class
+
+
+def get_confidence_scores(probabilities):
+    # get the indices of the probability scores sorted in descending order
+    labels = np.argsort(probabilities)[::-1]
+    confidences = {
+        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
+        for label in labels
+    }
+    return confidences
+
+
+"""
+**Get predictions**
+"""
+
+img_dir = "inference_set"
+predict_ds = create_tf_dataset(img_dir)
+probabilities = predict(predict_ds)
+print(f"probabilities: {probabilities[0]}")
+confidences = get_confidence_scores(probabilities[0])
+print(confidences)
+
+"""
+**View predictions**
+"""
+
+plt.figure(figsize=(10, 10))
+for images in predict_ds:
+    for i in range(min(6, probabilities.shape[0])):
+        ax = plt.subplot(3, 3, i + 1)
+        plt.imshow(images[i].numpy().astype("uint8"))
+        predicted_class = get_predicted_class(probabilities[i])
+        plt.title(predicted_class)
+        plt.axis("off")
+
 """
 ## Conclusion
 
@@ -866,3 +1045,11 @@ def __call__(self, step):
 - A personal note of thanks to [Puja
 Roychowdhury](https://twitter.com/pleb_talks) for helping us with the Learning
 Rate Schedule.
 """
+
+"""
+**Example available on HuggingFace**
+
+| Trained Model | Demo |
+| :--: | :--: |
+| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-ShiftViT-brightgreen)](https://huggingface.co/keras-io/shiftvit) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Space-ShiftViT-brightgreen)](https://huggingface.co/spaces/keras-io/shiftvit) |
+"""
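Since the trained model is also published on the Hub (see the badges above), it can be pulled down directly instead of being retrained. This is only a sketch under assumptions: it requires the `huggingface_hub` package and relies on its Keras integration (`from_pretrained_keras`) being available for the `keras-io/shiftvit` repository.

```python
# Sketch: load the published ShiftViT weights from the Hugging Face Hub.
# Assumes `pip install huggingface_hub` and that the repo stores the model
# in a format `from_pretrained_keras` can restore.
from huggingface_hub import from_pretrained_keras

pretrained_model = from_pretrained_keras("keras-io/shiftvit")
pretrained_model.summary()
```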