mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge remote-tracking branch 'b/nv_upstream_trt_cuequivariance' into nv_upstream_trt_cuequivariance
This commit is contained in:
@@ -546,8 +546,8 @@ def embed_templates_offload(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
@@ -667,8 +667,8 @@ def embed_templates_average(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
|
||||
@@ -46,7 +46,7 @@ from openfold.data import templates, feature_pipeline, data_pipeline
|
||||
from openfold.data.tools import hhsearch, hmmsearch
|
||||
from openfold.np import protein
|
||||
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
|
||||
prep_output)
|
||||
prep_output, relax_protein)
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
from openfold.utils.trace_utils import (
|
||||
pad_feature_dict_seq,
|
||||
|
||||
Reference in New Issue
Block a user