代码拉取完成,页面将自动刷新
import torch
from alphafold3 import AlphaFold3
# Create random tensors
x = torch.randn(
1, 5, 5, 64
) # Shape: (batch_size, seq_len, seq_len, dim)
y = torch.randn(1, 5, 64) # Shape: (batch_size, seq_len, dim)
# Initialize AlphaFold3 model
model = AlphaFold3(
dim=64,
seq_len=5,
heads=8,
dim_head=64,
attn_dropout=0.0,
ff_dropout=0.0,
global_column_attn=False,
pair_former_depth=48,
num_diffusion_steps=1000,
diffusion_depth=30,
)
# Forward pass through the model
output = model(x, y, return_confidence=True)
# Print the shape of the output tensor
print(output)
print(output.shape)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。