Skip to content

Commit cbefe01

Browse files
author
Vincent Moens
committed
[Setup] Making tensordict pytorch-agnostic
ghstack-source-id: 16e1d43 Pull Request resolved: #1256
1 parent cb81b5e commit cbefe01

File tree

10 files changed

+109
-48
lines changed

10 files changed

+109
-48
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#!/bin/bash
2+
3+
yum update gcc
4+
yum update libstdc++
5+
6+
${CONDA_RUN} pip install cmake pybind11

.github/scripts/win-pre-script.bat

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@echo off
2+
:: Check if CONDA_RUN is set, if not, set it to a default value
3+
if "%CONDA_RUN%"=="" (
4+
echo CONDA_RUN is not set. Please activate your conda environment or set CONDA_RUN.
5+
exit /b 1
6+
)
7+
8+
:: Run the pip install command
9+
%CONDA_RUN% pip install cmake pybind11
10+
11+
:: Check if the installation was successful
12+
if errorlevel 1 (
13+
echo Failed to install cmake and pybind11.
14+
exit /b 1
15+
) else (
16+
echo Successfully installed cmake and pybind11.
17+
)

.github/unittest/linux/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ dependencies:
1919
- coverage
2020
- h5py
2121
- orjson
22+
- ninja
23+
- pybind11
24+
- cmake

.github/unittest/linux_torchrec/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ dependencies:
1818
- coverage
1919
- h5py
2020
- orjson
21+
- cmake
22+
- ninja
23+
- pybind11

.github/unittest/rl_linux_optdeps/scripts/environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ dependencies:
1616
- pyyaml
1717
- scipy
1818
- orjson
19+
- cmake
20+
- ninja
21+
- pybind11

.github/workflows/build-wheels-linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
include:
3434
- repository: pytorch/tensordict
3535
smoke-test-script: test/smoke_test.py
36+
pre-script: .github/scripts/linux-pre-script.sh
3637
post-script: .github/scripts/linux-post-script.sh
3738
package-name: tensordict
3839
name: pytorch/tensordict

.github/workflows/build-wheels-windows.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
matrix:
3333
include:
3434
- repository: pytorch/tensordict
35-
pre-script: ""
35+
pre-script: .github/scripts/win-pre-script.bat
3636
env-script: .github/scripts/version_script.bat
3737
post-script: "python packaging/wheel/relocate.py"
3838
smoke-test-script: test/smoke_test.py

CONTRIBUTING.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ possible.
66
Install the library as suggested in the README. For advanced features,
77
it is preferable to install the nightly built of pytorch.
88

9+
You will need the following packages to be installed:
10+
```bash
11+
pip install ninja cmake pybind11 -U
12+
```
13+
914
Make sure you install tensordict in develop mode by running
1015
```
1116
python setup.py develop

setup.py

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import argparse
77
import distutils.command.clean
8-
import glob
98
import logging
109
import os
1110
import shutil
@@ -15,11 +14,23 @@
1514
from pathlib import Path
1615
from typing import List
1716

18-
from setuptools import find_packages, setup
19-
from torch.utils.cpp_extension import BuildExtension, CppExtension
17+
from setuptools import Extension, find_packages, setup
18+
from setuptools.command.build_ext import build_ext
2019

2120
ROOT_DIR = Path(__file__).parent.resolve()
2221

22+
23+
def get_python_executable():
24+
# Check if we're running in a virtual environment
25+
if "VIRTUAL_ENV" in os.environ:
26+
# Get the virtual environment's Python executable
27+
python_executable = os.path.join(os.environ["VIRTUAL_ENV"], "bin", "python")
28+
else:
29+
# Fall back to sys.executable
30+
python_executable = sys.executable
31+
return python_executable
32+
33+
2334
try:
2435
sha = (
2536
subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=ROOT_DIR)
@@ -69,7 +80,7 @@ def _get_pytorch_version(is_nightly, is_local):
6980
return "torch>=2.7.0.dev"
7081
if is_local:
7182
return "torch"
72-
return "torch>=2.6.0"
83+
return "torch>=2.5.0"
7384

7485

7586
def _get_packages():
@@ -99,51 +110,38 @@ def run(self):
99110
shutil.rmtree(str(path), ignore_errors=True)
100111

101112

102-
def get_extensions():
103-
extension = CppExtension
104-
105-
extra_link_args = []
106-
extra_compile_args = {
107-
"cxx": [
108-
"-O3",
109-
"-std=c++17",
110-
"-fdiagnostics-color=always",
113+
class CMakeExtension(Extension):
114+
def __init__(self, name, sourcedir=""):
115+
super().__init__(name, sources=[])
116+
self.sourcedir = os.path.abspath(sourcedir)
117+
118+
119+
class CMakeBuild(build_ext):
120+
def run(self):
121+
for ext in self.extensions:
122+
self.build_extension(ext)
123+
124+
def build_extension(self, ext):
125+
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
126+
cmake_args = [
127+
f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
128+
f"-DPYTHON_EXECUTABLE={get_python_executable()}",
129+
f"-DPython3_EXECUTABLE={get_python_executable()}",
111130
]
112-
}
113-
debug_mode = os.getenv("DEBUG", "0") == "1"
114-
if debug_mode:
115-
logging.info("Compiling in debug mode")
116-
extra_compile_args = {
117-
"cxx": [
118-
"-O0",
119-
"-fno-inline",
120-
"-g",
121-
"-std=c++17",
122-
"-fdiagnostics-color=always",
123-
]
124-
}
125-
extra_link_args = ["-O0", "-g"]
126-
127-
this_dir = os.path.dirname(os.path.abspath(__file__))
128-
extensions_dir = os.path.join(this_dir, "tensordict", "csrc")
129-
130-
extension_sources = {
131-
os.path.join(extensions_dir, p)
132-
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
133-
}
134-
sources = list(extension_sources)
135-
136-
ext_modules = [
137-
extension(
138-
"tensordict._C",
139-
sources,
140-
include_dirs=[this_dir],
141-
extra_compile_args=extra_compile_args,
142-
extra_link_args=extra_link_args,
131+
build_args = []
132+
if not os.path.exists(self.build_temp):
133+
os.makedirs(self.build_temp)
134+
subprocess.check_call(
135+
["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp
136+
)
137+
subprocess.check_call(
138+
["cmake", "--build", "."] + build_args, cwd=self.build_temp
143139
)
144-
]
145140

146-
return ext_modules
141+
142+
def get_extensions():
143+
extensions_dir = os.path.join(ROOT_DIR, "tensordict", "csrc")
144+
return [CMakeExtension("tensordict._C", sourcedir=extensions_dir)]
147145

148146

149147
def _main(argv):
@@ -181,7 +179,7 @@ def _main(argv):
181179
),
182180
ext_modules=get_extensions(),
183181
cmdclass={
184-
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
182+
"build_ext": CMakeBuild,
185183
"clean": clean,
186184
},
187185
install_requires=[

tensordict/csrc/CMakeLists.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
cmake_minimum_required(VERSION 3.12)
2+
project(tensordict)
3+
4+
set(CMAKE_CXX_STANDARD 17)
5+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
6+
7+
# Set the Python executable to the one from your virtual environment
8+
# set(Python3_EXECUTABLE "/Users/vmoens/venv/rl2/bin/python3.10")
9+
10+
find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
11+
find_package(pybind11 REQUIRED)
12+
13+
file(GLOB SOURCES "*.cpp")
14+
15+
add_library(_C MODULE ${SOURCES})
16+
17+
set_target_properties(_C PROPERTIES
18+
OUTPUT_NAME "_C"
19+
LIBRARY_OUTPUT_DIRECTORY "${PROJECT_SOURCE_DIR}/../"
20+
PREFIX "" # Remove 'lib' prefix
21+
SUFFIX ".so" # Ensure correct suffix for macOS/Linux
22+
)
23+
24+
target_include_directories(_C PRIVATE ${PROJECT_SOURCE_DIR})
25+
target_link_libraries(_C PRIVATE Python3::Python pybind11::module)

0 commit comments

Comments
 (0)