代码拉取完成,页面将自动刷新
import os
from setuptools import find_packages
from distutils.core import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def make_cuda_ext(name, module, sources):
cuda_ext = CUDAExtension(
name='%s.%s' % (module, name),
sources=[os.path.join(*module.split('.'), src) for src in sources]
)
return cuda_ext
setup(
name='transformer4planning',
version='1.0.0',
author='QiaoSun & Shiduo-zh',
license="MIT",
packages=find_packages(),
author_email='',
description='',
install_requires=[],
cmdclass={
'build_ext': BuildExtension,
},
ext_modules=[
make_cuda_ext(
name='knn_cuda',
module='transformer4planning.libs.mtr.ops.knn',
sources=[
'src/knn.cpp',
'src/knn_gpu.cu',
'src/knn_api.cpp',
],
),
make_cuda_ext(
name='attention_cuda',
module='transformer4planning.libs.mtr.ops.attention',
sources=[
'src/attention_api.cpp',
'src/attention_func_v2.cpp',
'src/attention_func.cpp',
'src/attention_value_computation_kernel_v2.cu',
'src/attention_value_computation_kernel.cu',
'src/attention_weight_computation_kernel_v2.cu',
'src/attention_weight_computation_kernel.cu',
],
),
],
)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。