Merge remote-tracking branch 'b/nv_upstream_trt_cuequivariance' into nv_upstream_trt_cuequivariance

This commit is contained in:
Boris Fomitchev
2025-11-07 13:38:19 -08:00
2 changed files with 5 additions and 5 deletions

View File

@@ -546,8 +546,8 @@ def embed_templates_offload(
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,
@@ -667,8 +667,8 @@ def embed_templates_average(
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,

View File

@@ -46,7 +46,7 @@ from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.data.tools import hhsearch, hmmsearch
from openfold.np import protein
from openfold.utils.script_utils import (load_models_from_command_line, parse_fasta, run_model,
prep_output)
prep_output, relax_protein)
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.trace_utils import (
pad_feature_dict_seq,