From 2fc241af6d32dcaa3259935a44a661ba7eeb3d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Tubiana?= Date: Wed, 10 Nov 2021 18:00:29 +0200 Subject: [PATCH] Support for attention coefficient extraction --- .DS_Store | Bin 18436 -> 18436 bytes .gitignore | 4 +++- predict_bindingsites.py | 7 ++++--- predictions/.DS_Store | Bin 6148 -> 10244 bytes 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.DS_Store b/.DS_Store index 6fbc6bc891b91466387942564bf3b61c103af8c4..f1a836d205835bf20bf14ff5ee8bb21a7a3b3468 100644 GIT binary patch delta 915 zcmZvaOH30%9LD$CBFruUrct-BR9zH;)GBQODTo^7q4)qF5b+UB%PzWFTIeqLKt&CS zCl#DTO+29&FD7aeZyt;qZ^j@84r+YdG$tk{CdNda-S(h~m)(5-Z|47>`F%4niHS+v ztMRg~m4l^aV`JfHB&db!BWHrzSUA!U4Qh>{de&Xpz&?7DKJiO@C}CQ5C^QGrlQ%EF zfJkzo!Zt-GA+xz?&A3-AmZVu)TvDX?lrln>k-m|*IXq(8CufU_OK7=Ykq_uG+t{Ac z%>mQUX@*FyqxltOzvHx~4gS>$am$reirizFrwkg)EvyO1a?dd{sZ)!Oa+3!hnW-e*nkz<|WTme^H8^M_ZOe;buDq6yKIn|z7PIW`AtUKd zcQ4^rS_#?fcpG%x(A{zE9g4EgIuc78-N%zgn$F~ox}`&@Qx4G**K%65>!zJb=iDy7 z>ky@8a@4eAHb2#F{u{9fVZ#f2>~)L8I-8e|&j_!CS>cWFR(L187d{E!AV7o%c_=~| zRFtCPA8+{^qH&ui=Gp00r>EhffunO{EL!(M6=BvRMd5Y+0!Y}&h#wUy7$wE3TVrp_(O6l%3a>}SQl PD>4PEJYuFW^uzZDSnl8+ delta 709 zcmZ|NT}YE*6bJD0f7i#}XIih*ZO-K!vx(WJb9={9XW1+blA7gAi85=WmO3dR$yFwp zR0J6h_I=lfbtC3Lx)916p{xr-D~P(tE+Q%lgdn=;!9`c~TpW1*=Ws55oc=-d4`Osp z1`l~Vbml~&)Z=#AMB!7eAfL~Ob5+e!n9dSzWKB)61D)}mv7RXJ(+Bxvfr0lKGVS?Q zZd!cuNSYznUa;2fD_>t-TNj#5U$IhQdX~jojM;oa-2%odeeUh7&{8BTt=fhzu_`s> zVkNTDB{#Ii;*rI~;$f;*iOAvh*8S1;RODx6T4k%e>sa(?SFAIh3iXUn_=bCB1*9c2 z*-51Ug{YY#)J+#?h%VC|8liETpedT67c@sN=?#6LPqaY4uoTOog$zA1VSo|&aKeRR zl%pD(P=i|3A&f>ep&5H|5QlIWZ8(A+BybX^Fo25~#3dEiaSOL`7sD9GBRs)VOk);v zc!k%P$7g)OSBWH9VvN_kSD4bn9 zz*(Hr<2*Qqa0OR!Z4qpg-!tXm0sp4tB;T7N8cg#U+v&8NM8c!`RHx15EfaTW#SQ-5 el82kgj8$(nNu8}Eou6?$Pil)bQj)oUntlT_qOA!4 diff --git a/.gitignore b/.gitignore index c8df503..46bf28c 100644 --- a/.gitignore +++ b/.gitignore @@ -132,4 +132,6 @@ models/*check* models/*handcrafted* data/pdb/*.txt -data/tmp/* \ No newline at end of file +data/tmp/* +predictions/.DS_Store +.DS_Store diff --git a/predict_bindingsites.py b/predict_bindingsites.py index 2027cdc..d1ba661 100644 --- a/predict_bindingsites.py +++ b/predict_bindingsites.py @@ -74,7 +74,7 @@ def write_predictions(csv_file, residue_ids, sequence, interface_prediction): if interface_prediction.ndim == 1: columns.append('Binding site probability') else: - columns += ['Output %s' %i for i in range(len(interface_prediction) )] + columns += ['Output %s' %i for i in range(interface_prediction.shape[-1] )] with open(csv_file, 'w') as f: f.write(','.join(columns) + '\n' ) @@ -492,7 +492,7 @@ def predict_interface_residues( for attention_coeff,neighborhood_graph in zip(attention_coeffs,neighborhood_graphs): aggregated_attention_coeff = np.zeros(len(attention_coeff),dtype=np.float32) for s in range( len(attention_coeff) ): - aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s],0).mean( + aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s][:len(neighborhood_graph[s])],0).mean( -1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads. # aggregated_attention_coeff[neighborhood_graph[s]] += np.abs(attention_coeff[s]).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads. aggregated_attention_coeffs.append(aggregated_attention_coeff) @@ -556,7 +556,7 @@ def predict_interface_residues( for attention_coeff,neighborhood_graph in zip(attention_coeffs,neighborhood_graphs): aggregated_attention_coeff = np.zeros(len(attention_coeff),dtype=np.float32) for s in range( len(attention_coeff) ): - aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s],0).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads. + aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s][:len(neighborhood_graph[s])],0).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads. aggregated_attention_coeffs.append(aggregated_attention_coeff) aggregated_attention_coeffs = np.array(aggregated_attention_coeffs) if layer == 'attention_layer': @@ -611,6 +611,7 @@ def predict_interface_residues( if not aggregate_models: # multioutput. for layer_,prediction in zip(layer,predictions): if layer_ is None: + prediction = prediction[:,1] csv_file = query_output_folder + 'predictions_' + query_name + '.csv' chimera_file = query_output_folder + 'chimera_' + query_names[i] annotated_pdb_file = query_output_folder + 'annotated_' + query_names[i] + ('.cif' if file_is_cif else '.pdb') diff --git a/predictions/.DS_Store b/predictions/.DS_Store index 440595fc79f064acbb1808a1e35f2d210a82f559..253b917ff60f1f3094951a6c93618dbd3375faf2 100644 GIT binary patch literal 10244 zcmeHM&2G~`5FWPyoB*OoNCk<54$^j%0QV<*frE!wB7O7pu4go?? z9)q730$zaE;hSAI1cxeZN!kj{O1raO@9f&&=bhad6OmDB*Bv5@h)iUbbC*#3CGmaE zBWWgXxe94OpQu6$WYZF@)8a^PX%GYi0YN|z5CjB)vw#5hY)cT%?>?Fbtj{!rIdnzATWvm`|dZ%p(-uY+dusG3YDpbcAp+ogCgAX zP>)b=(H7ck)ToEGrW*uJ9(Ht)5`q9n9OTZ>lem~Kpf=|8)a;>}f18?kfX%B-{ASX73+X?sFX5a#eZ|+{x3nSb9Y0S)~&Yiztnwd=Ia^_WeC+L); zYSa!()z%BXdQy#c+_?S}Ko06v_nFt|ROYVT3;d|+`HhxR_Ua9EdHLM)>p`a!wEcRh z#wqI?rkPHsD|7SR?%G<`TFdpaR@Yg!QD0eI?)B2<{B7s%gU#J{@4Fv*A3rf|u#~Yy zr7QY+P<}>jN@%?)xBO<{Z}arX{O!2-tU_?|JAiZul9R;N=rx8CVikvDiJbyk31a#4 z*#al7aH1#~zu2aLR)UjTlmjC>u#KbZWUR%9&cLb;nYqujTh57oZ=58LG!V$GA<%S`NLAl6yU zf!Ms>&2oiR5Nqd;MeG!?O0cn{4otj6BGmY>R0l77yxLXAtI4qP{j`_cbu2e$W&9C? zj}S-)?nJyArwCbs9Q7uvacmtkWTz(K)KQ|D*c