Skip to content

Commit 0f10863

Browse files
author
Vincent Moens
committed
[Setup] Making tensordict pytorch-agnostic
ghstack-source-id: 599c5ac Pull Request resolved: #1256
1 parent cb81b5e commit 0f10863

File tree

2 files changed

+70
-47
lines changed

2 files changed

+70
-47
lines changed

setup.py

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import argparse
77
import distutils.command.clean
8-
import glob
98
import logging
109
import os
1110
import shutil
@@ -15,11 +14,23 @@
1514
from pathlib import Path
1615
from typing import List
1716

18-
from setuptools import find_packages, setup
19-
from torch.utils.cpp_extension import BuildExtension, CppExtension
17+
from setuptools import Extension, find_packages, setup
18+
from setuptools.command.build_ext import build_ext
2019

2120
ROOT_DIR = Path(__file__).parent.resolve()
2221

22+
23+
def get_python_executable():
24+
# Check if we're running in a virtual environment
25+
if "VIRTUAL_ENV" in os.environ:
26+
# Get the virtual environment's Python executable
27+
python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python")
28+
else:
29+
# Fall back to sys.executable
30+
python_executable = sys.executable
31+
return python_executable
32+
33+
2334
try:
2435
sha = (
2536
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=ROOT_DIR)
@@ -69,7 +80,7 @@ def _get_pytorch_version(is_nightly, is_local):
6980
return "torch>=2.7.0.dev"
7081
if is_local:
7182
return "torch"
72-
return "torch>=2.6.0"
83+
return "torch>=2.5.0"
7384

7485

7586
def _get_packages():
@@ -99,51 +110,38 @@ def run(self):
99110
shutil.rmtree(str(path), ignore_errors=True)
100111

101112

102-
def get_extensions():
103-
extension = CppExtension
104-
105-
extra_link_args = []
106-
extra_compile_args = {
107-
"cxx": [
108-
"-O3",
109-
"-std=c++17",
110-
"-fdiagnostics-color=always",
113+
class CMakeExtension(Extension):
114+
def __init__(self, name, sourcedir=""):
115+
super().__init__(name, sources=[])
116+
self.sourcedir = os.path.abspath(sourcedir)
117+
118+
119+
class CMakeBuild(build_ext):
120+
def run(self):
121+
for ext in self.extensions:
122+
self.build_extension(ext)
123+
124+
def build_extension(self, ext):
125+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
126+
cmake_args = [
127+
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
128+
f"-DPYTHON_EXECUTABLE={get_python_executable()}",
129+
f"-DPython3_EXECUTABLE={get_python_executable()}",
111130
]
112-
}
113-
debug_mode = os.getenv("DEBUG", "0") == "1"
114-
if debug_mode:
115-
logging.info("Compiling in debug mode")
116-
extra_compile_args = {
117-
"cxx": [
118-
"-O0",
119-
"-fno-inline",
120-
"-g",
121-
"-std=c++17",
122-
"-fdiagnostics-color=always",
123-
]
124-
}
125-
extra_link_args = ["-O0", "-g"]
126-
127-
this_dir = os.path.dirname(os.path.abspath(__file__))
128-
extensions_dir = os.path.join(this_dir, "tensordict", "csrc")
129-
130-
extension_sources = {
131-
os.path.join(extensions_dir, p)
132-
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
133-
}
134-
sources = list(extension_sources)
135-
136-
ext_modules = [
137-
extension(
138-
"tensordict._C",
139-
sources,
140-
include_dirs=[this_dir],
141-
extra_compile_args=extra_compile_args,
142-
extra_link_args=extra_link_args,
131+
build_args = []
132+
if not os.path.exists(self.build_temp):
133+
os.makedirs(self.build_temp)
134+
subprocess.check_call(
135+
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
136+
)
137+
subprocess.check_call(
138+
["cmake", "--build", "."] + build_args, cwd=self.build_temp
143139
)
144-
]
145140

146-
return ext_modules
141+
142+
def get_extensions():
143+
extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc")
144+
return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)]
147145

148146

149147
def _main(argv):
@@ -181,7 +179,7 @@ def _main(argv):
181179
),
182180
ext_modules=get_extensions(),
183181
cmdclass={
184-
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
182+
"build_ext": CMakeBuild,
185183
"clean": clean,
186184
},
187185
install_requires=[

tensordict/csrc/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
cmake_minimum_required(VERSION 3.12)
2+
project(tensordict)
3+
4+
set(CMAKE_CXX_STANDARD 17)
5+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
6+
7+
# Set the Python executable to the one from your virtual environment
8+
# set(Python3_EXECUTABLE "/Users/vmoens/venv/rl2/bin/python3.10")
9+
10+
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
11+
find_package(pybind11 REQUIRED)
12+
13+
file(GLOB SOURCES "*.cpp")
14+
15+
add_library(_C MODULE ${SOURCES})
16+
17+
set_target_properties(_C PROPERTIES
18+
OUTPUT_NAME "_C"
19+
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/../"
20+
PREFIX "" # Remove 'lib' prefix
21+
SUFFIX ".so" # Ensure correct suffix for macOS/Linux
22+
)
23+
24+
target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
25+
target_link_libraries(_C PRIVATE Python3::Python pybind11::module)

0 commit comments

Comments
 (0)