加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
model_example.py 631 Bytes
一键复制 编辑 原始数据 按行查看 历史
Kye 提交于 2024-05-09 21:41 . [EARLY PROTOTYPE]
import torch
from alphafold3 import AlphaFold3
# Create random tensors
x = torch.randn(
1, 5, 5, 64
) # Shape: (batch_size, seq_len, seq_len, dim)
y = torch.randn(1, 5, 64) # Shape: (batch_size, seq_len, dim)
# Initialize AlphaFold3 model
model = AlphaFold3(
dim=64,
seq_len=5,
heads=8,
dim_head=64,
attn_dropout=0.0,
ff_dropout=0.0,
global_column_attn=False,
pair_former_depth=48,
num_diffusion_steps=1000,
diffusion_depth=30,
)
# Forward pass through the model
output = model(x, y, return_confidence=True)
# Print the shape of the output tensor
print(output)
print(output.shape)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化