Skip to content

Commit d1a0c6f

Browse files
author
Vincent Moens
committed
[Setup] Making tensordict pytorch-agnostic
ghstack-source-id: d6e8eb6 Pull Request resolved: #1256
1 parent cb81b5e commit d1a0c6f

File tree

10 files changed

+131
-47
lines changed

10 files changed

+131
-47
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
3+
yum update gcc
4+
yum update libstdc++
5+
6+
conda install conda-forge::pybind11 -y

.github/scripts/win-pre-script.bat

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@echo off
2+
:: Check if CONDA_RUN is set, if not, set it to a default value
3+
if "%CONDA_RUN%"=="" (
4+
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
5+
exit /b 1
6+
)
7+
8+
:: Run the pip install command
9+
%CONDA_RUN% python -m pip install cmake pybind11 -U
10+
11+
:: Check if the installation was successful
12+
if errorlevel 1 (
13+
echo Failed to install cmake and pybind11.
14+
exit /b 1
15+
) else (
16+
echo Successfully installed cmake and pybind11.
17+
)

.github/unittest/linux/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ dependencies:
1919
- coverage
2020
- h5py
2121
- orjson
22+
- ninja
23+
- pybind11
24+
- cmake

.github/unittest/linux_torchrec/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ dependencies:
1818
- coverage
1919
- h5py
2020
- orjson
21+
- cmake
22+
- ninja
23+
- pybind11

.github/unittest/rl_linux_optdeps/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ dependencies:
1616
- pyyaml
1717
- scipy
1818
- orjson
19+
- cmake
20+
- ninja
21+
- pybind11

.github/workflows/build-wheels-linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
include:
3434
- repository: pytorch/tensordict
3535
smoke-test-script: test/smoke_test.py
36+
pre-script: .github/scripts/linux-pre-script.sh
3637
post-script: .github/scripts/linux-post-script.sh
3738
package-name: tensordict
3839
name: pytorch/tensordict

.github/workflows/build-wheels-windows.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
matrix:
3333
include:
3434
- repository: pytorch/tensordict
35-
pre-script: ""
35+
pre-script: .github/scripts/win-pre-script.bat
3636
env-script: .github/scripts/version_script.bat
3737
post-script: "python packaging/wheel/relocate.py"
3838
smoke-test-script: test/smoke_test.py

CONTRIBUTING.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ possible.
66
Install the library as suggested in the README. For advanced features,
77
it is preferable to install the nightly built of pytorch.
88

9+
You will need the following packages to be installed:
10+
```bash
11+
pip install ninja cmake pybind11 -U
12+
```
13+
914
Make sure you install tensordict in develop mode by running
1015
```
1116
python setup.py develop

setup.py

Lines changed: 67 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,23 @@
1515
from pathlib import Path
1616
from typing import List
1717

18-
from setuptools import find_packages, setup
19-
from torch.utils.cpp_extension import BuildExtension, CppExtension
18+
from setuptools import Extension, find_packages, setup
19+
from setuptools.command.build_ext import build_ext
2020

2121
ROOT_DIR = Path(__file__).parent.resolve()
2222

23+
24+
def get_python_executable():
25+
# Check if we're running in a virtual environment
26+
if "VIRTUAL_ENV" in os.environ:
27+
# Get the virtual environment's Python executable
28+
python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python")
29+
else:
30+
# Fall back to sys.executable
31+
python_executable = sys.executable
32+
return python_executable
33+
34+
2335
try:
2436
sha = (
2537
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=ROOT_DIR)
@@ -69,7 +81,7 @@ def _get_pytorch_version(is_nightly, is_local):
6981
return "torch>=2.7.0.dev"
7082
if is_local:
7183
return "torch"
72-
return "torch>=2.6.0"
84+
return "torch>=2.5.0"
7385

7486

7587
def _get_packages():
@@ -99,51 +111,60 @@ def run(self):
99111
shutil.rmtree(str(path), ignore_errors=True)
100112

101113

102-
def get_extensions():
103-
extension = CppExtension
104-
105-
extra_link_args = []
106-
extra_compile_args = {
107-
"cxx": [
108-
"-O3",
109-
"-std=c++17",
110-
"-fdiagnostics-color=always",
114+
class CMakeExtension(Extension):
115+
def __init__(self, name, sourcedir=""):
116+
super().__init__(name, sources=[])
117+
self.sourcedir = os.path.abspath(sourcedir)
118+
119+
120+
class CMakeBuild(build_ext):
121+
def run(self):
122+
for ext in self.extensions:
123+
self.build_extension(ext)
124+
125+
def build_extension(self, ext):
126+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
127+
cmake_args = [
128+
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
129+
f"-DPYTHON_EXECUTABLE={get_python_executable()}",
130+
f"-DPython3_EXECUTABLE={get_python_executable()}",
111131
]
112-
}
113-
debug_mode = os.getenv("DEBUG", "0") == "1"
114-
if debug_mode:
115-
logging.info("Compiling in debug mode")
116-
extra_compile_args = {
117-
"cxx": [
118-
"-O0",
119-
"-fno-inline",
120-
"-g",
121-
"-std=c++17",
122-
"-fdiagnostics-color=always",
123-
]
124-
}
125-
extra_link_args = ["-O0", "-g"]
126-
127-
this_dir = os.path.dirname(os.path.abspath(__file__))
128-
extensions_dir = os.path.join(this_dir, "tensordict", "csrc")
129-
130-
extension_sources = {
131-
os.path.join(extensions_dir, p)
132-
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
133-
}
134-
sources = list(extension_sources)
135-
136-
ext_modules = [
137-
extension(
138-
"tensordict._C",
139-
sources,
140-
include_dirs=[this_dir],
141-
extra_compile_args=extra_compile_args,
142-
extra_link_args=extra_link_args,
132+
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
133+
# if CONDA_PREFIX:
134+
# CMAKE_PREFIX_PATH = os.environ.get("CMAKE_PREFIX_PATH")
135+
# if CMAKE_PREFIX_PATH:
136+
# cmake_args.append(f"-DCMAKE_PREFIX_PATH={CONDA_PREFIX}:{CMAKE_PREFIX_PATH}")
137+
# else:
138+
# cmake_args.append(f"-DCMAKE_PREFIX_PATH={CONDA_PREFIX}")
139+
if CONDA_PREFIX:
140+
# Find pybind11
141+
pybind11_dir = None
142+
for config_file in ["pybind11Config.cmake", "pybind11-config.cmake"]:
143+
config_path = glob.glob(
144+
os.path.join(CONDA_PREFIX, "**", config_file), recursive=True
145+
)
146+
if config_path:
147+
pybind11_dir = os.path.dirname(config_path[0])
148+
break
149+
else:
150+
raise RuntimeError(f"could not find any of 'pybind11Config.cmake', 'pybind11-config.cmake'")
151+
if pybind11_dir:
152+
cmake_args.append(f"-DPYBIND11_DIR={pybind11_dir}")
153+
154+
build_args = []
155+
if not os.path.exists(self.build_temp):
156+
os.makedirs(self.build_temp)
157+
subprocess.check_call(
158+
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
159+
)
160+
subprocess.check_call(
161+
["cmake", "--build", "."] + build_args, cwd=self.build_temp
143162
)
144-
]
145163

146-
return ext_modules
164+
165+
def get_extensions():
166+
extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc")
167+
return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)]
147168

148169

149170
def _main(argv):
@@ -181,7 +202,7 @@ def _main(argv):
181202
),
182203
ext_modules=get_extensions(),
183204
cmdclass={
184-
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
205+
"build_ext": CMakeBuild,
185206
"clean": clean,
186207
},
187208
install_requires=[

tensordict/csrc/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
cmake_minimum_required(VERSION 3.12)
2+
project(tensordict)
3+
4+
set(CMAKE_CXX_STANDARD 17)
5+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
6+
7+
# Set the Python executable to the one from your virtual environment
8+
# set(Python3_EXECUTABLE "/Users/vmoens/venv/rl2/bin/python3.10")
9+
10+
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
11+
find_package(pybind11 REQUIRED)
12+
13+
file(GLOB SOURCES "*.cpp")
14+
15+
add_library(_C MODULE ${SOURCES})
16+
17+
set_target_properties(_C PROPERTIES
18+
OUTPUT_NAME "_C"
19+
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/../"
20+
PREFIX "" # Remove 'lib' prefix
21+
SUFFIX ".so" # Ensure correct suffix for macOS/Linux
22+
)
23+
24+
target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
25+
target_link_libraries(_C PRIVATE Python3::Python pybind11::module)

0 commit comments

Comments
 (0)