40 commits
048418e
delete tiramisu related files
skourta Dec 5, 2024
3aefdd0
add tiralib dep
skourta Dec 5, 2024
cd72e79
implement tiralib interface with gnn
skourta Dec 5, 2024
aac9ea7
remove unused functions from graph utils
skourta Dec 5, 2024
54a4c00
update rollout worker to use tiramisu interface
skourta Dec 5, 2024
46e4582
update dataset actor
skourta Dec 5, 2024
ad6e887
update training script and config
skourta Dec 5, 2024
ffbe7ac
added cleanup of files
skourta Dec 6, 2024
287273f
use the new isl tree to get the graph
skourta Jan 9, 2025
33c3487
set lower and upper bound to 0 if not int
skourta Jan 9, 2025
c66b377
make embedding size 720 to avoid embeddings crossing each other
skourta Jan 10, 2025
f076ce3
augmented embeddings to 720 and done if depth of tree > max iterators
skourta Jan 10, 2025
069f219
Get computation embedding from initial annotations
skourta Jan 22, 2025
ef57396
added some logs
skourta Jan 23, 2025
65d27ca
add config option to clean files or not
skourta Jan 24, 2025
6c589fc
added cache
skourta Feb 5, 2025
d7b1197
added durations spent in different phases
skourta Feb 6, 2025
5ac11f7
add tiralib_config_path to config.yaml.temp
skourta Feb 11, 2025
bc76f18
add readme file
skourta Feb 13, 2025
f8d410c
fix major bug and add evaluation file
skourta Feb 14, 2025
318d555
Merge branch 'tiralib-integration' of github.com:Tiramisu-Compiler/gn…
skourta Feb 14, 2025
2dd5307
fix train ppo
skourta Feb 14, 2025
84cd431
upgrade ruff
skourta Mar 3, 2025
10fbb98
update poetry.lock
skourta Mar 3, 2025
2cfb25f
refactoring and docs
skourta Mar 3, 2025
e5ea245
add tests for mask
skourta Mar 3, 2025
40eaff6
add test for init server
skourta Mar 3, 2025
4d3aadb
update isl write matrix helper and some more tests
skourta Mar 4, 2025
7d1271d
more unit tests and some refactoring
skourta Mar 5, 2025
c758b86
couple of bug fixes and tests
skourta Mar 6, 2025
d604f3e
finish tests for tiramisu_interface
skourta Mar 7, 2025
f2ffe18
code refactoring and tests for data actor
skourta Mar 10, 2025
8278c6a
rollout_worker tests and update poetry.lock
skourta Mar 14, 2025
5e46bcb
add compiling schedules to tiramisu interface
skourta Mar 19, 2025
ccf6e62
add test remove deadcode and unused condition
skourta Mar 24, 2025
eb81394
raise on server crash
skourta Apr 3, 2025
d496482
make unrolling start from 2 instead of 1 in factor
skourta Apr 4, 2025
ddf0f8f
fixed unrolling tests and added mocking of execution everywhere
skourta Apr 7, 2025
28ddb6c
retry if init crashes
skourta Apr 22, 2025
9123534
check if datasetwork is ray actor
skourta Apr 24, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -13,3 +13,4 @@ dataset
!requirements.txt
*.yaml

workspace
40 changes: 40 additions & 0 deletions README.md
@@ -0,0 +1,40 @@
# Tiramisu GNN Autoscheduler

This is the repository for the Tiramisu GNN Autoscheduler, a tool for training graph-neural-network-based reinforcement learning agents. The tool is built on top of [Tiramisu](https://github.com/Tiramisu-Compiler/tiramisu), a polyhedral compiler for deep learning and scientific computing.

The project uses [TiraLib](https://github.com/Tiramisu-Compiler/TiraLib) as the backend for communication with the Tiramisu compiler. TiraLib is a Python library that provides an interface for writing and running Tiramisu programs. It uses [TiralibCpp](https://github.com/skourta/TiraLibCpp) as a backend to compile and run the Tiramisu programs.

## Installation

To install the Tiramisu GNN Autoscheduler, you need to install the following dependencies:
- [Tiramisu](https://github.com/Tiramisu-Compiler/tiramisu)
- [TiralibCpp](https://github.com/skourta/TiraLibCpp)
- [Poetry](https://python-poetry.org/)

After installing the dependencies, you can run the following commands to train an agent:

- Clone the repository:
```bash
git clone https://github.com/Tiramisu-Compiler/gnn_rl
```
- Install the requirements:
```bash
poetry install
```

- Create a training config by copying the example config, and make sure the directories containing the header and library files for Tiramisu and TiraLibCpp are correctly set in the config:
```bash
cp config/config.yaml.temp config/config.yaml
```


- Create a TiraLib config just like the [example template](https://github.com/Tiramisu-Compiler/TiraLib/blob/main/config.yaml.example) and save it in the same directory as `config.yaml`:
```bash
wget https://raw.githubusercontent.com/Tiramisu-Compiler/TiraLib/main/config.yaml.example -O config/tiralib_config.yaml
```

- Use the `rl_exec_job.sh.temp` script as a template to create a script that will run the training job. Make sure to update the paths to the config files and the Tiramisu GNN Autoscheduler repository in the script.

- Alternatively, you can run the training on the current machine with the following command:
```bash
python train_ppo_gnn.py --num-nodes=$NBR_NODES --name=$NAME_OF_TRAINING
```
222 changes: 70 additions & 152 deletions agent/graph_utils.py
@@ -2,40 +2,76 @@
import re


def isl_to_write_matrix(isl_map):
comp_iterators_str = re.findall(r"\[(.*)\]\s*->", isl_map)[0]
buffer_iterators_str = re.findall(r"->\s*\w*\[(.*)\]", isl_map)[0]
buffer_iterators_str = re.sub(r"\w+'\s=", "", buffer_iterators_str)
comp_iter_names = re.findall(r"(?:\s*(\w+))+", comp_iterators_str)
buf_iter_names = re.findall(r"(?:\s*(\w+))+", buffer_iterators_str)
matrix = np.zeros([len(buf_iter_names), len(comp_iter_names) + 1])
for i, buf_iter in enumerate(buf_iter_names):
for j, comp_iter in enumerate(comp_iter_names):
if buf_iter == comp_iter:
matrix[i, j] = 1
break
return matrix


def iterators_to_vectors(annotations):
it_dict = {}
iter_vector_size = 718
size_of_comp_vector = 709
for it in annotations["iterators"]:
single_iter_vector = -np.ones(iter_vector_size)
single_iter_vector[0] = 0
single_iter_vector[-9:] = 0
# lower value
single_iter_vector[size_of_comp_vector + 1] = annotations["iterators"][it][
"lower_bound"
]
# upper value
single_iter_vector[size_of_comp_vector + 2] = annotations["iterators"][it][
"upper_bound"
]
it_dict[it] = single_iter_vector

return it_dict
def parse_isl_map(isl_map: str):
# Extract the computation and buffer parts
match = re.match(r".*{\s*(\w+)\[([^\]]+)\]\s*->\s*(\w+)\[([^\]]+)\]\s*}", isl_map)
if not match:
raise ValueError("Invalid ISL map format")

_ = match.group(1)
comp_iterators = match.group(2).split(",")
_ = match.group(3)
buffer_accesses = match.group(4).split(",")

# Strip any spaces around iterators or accesses
comp_iterators = [it.strip() for it in comp_iterators]
buffer_accesses = [access.strip() for access in buffer_accesses]

return comp_iterators, buffer_accesses
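
# Illustrative example (not part of this diff): for a simple write-access
# map, parse_isl_map splits out the computation iterators and the buffer
# access expressions. The map string below is hypothetical.
#   parse_isl_map("{ comp00[i, j] -> buf00[i, j + 1] }")
#   # -> (["i", "j"], ["i", "j + 1"])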


def extract_affine_coefficients(access, comp_iterators):
"""
Extract coefficients of iterators in the affine expression.
Example: access = 'i1 + 2*j2', comp_iterators = ['i1', 'j2', 'k']
Output: [1, 2, 0, 0] (coefficients for i1, j2, k, and scalar term)
"""
# remove all spaces
access = access.replace(" ", "")

# Initialize coefficients (including one for the scalar term)
coefficients = [0] * (len(comp_iterators) + 1)

# Match terms like '2*i1' or 'i1' or '-i1', and extract the coefficients
for i, it in enumerate(comp_iterators):
# Modify the pattern to match iterator names with numbers or underscores
term_pattern = re.compile(r"([+-]?\d*)\s*\*?\s*(" + re.escape(it) + r")(\b|$)")
match = term_pattern.search(access)
if match:
coeff = match.group(1)
            if coeff in ("", "+"):
                coefficients[i] = 1
            elif coeff == "-":
                coefficients[i] = -1
            else:
                coefficients[i] = int(coeff)

# Handle the scalar part, which is just a constant not attached to any iterator
scalar_pattern = re.compile(
r"([+-]?\b\d+\b)(?![\*\w])"
) # Match isolated constants (scalars) not tied to iterators
scalar_match = scalar_pattern.search(access)

if scalar_match:
coefficients[-1] = int(scalar_match.group(1))

return coefficients
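
# Illustrative example (not part of this diff), mirroring the docstring above:
#   extract_affine_coefficients("i1 + 2*j2", ["i1", "j2", "k"])
#   # -> [1, 2, 0, 0]  (coefficients for i1, j2, k, plus the scalar term)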


def isl_map_to_write_access_matrix(isl_map: str):
# Parse the ISL map
comp_iterators, buffer_accesses = parse_isl_map(isl_map)

# Initialize the access matrix (rows = buffer dimensions, columns = iterators + scalar)
access_matrix = []

# Process each buffer access and extract affine coefficients
for access in buffer_accesses:
row = extract_affine_coefficients(access, comp_iterators)
access_matrix.append(row)

return np.array(access_matrix)
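
# Illustrative example (not part of this diff): combining the two helpers,
# each row corresponds to one buffer dimension and each column to an
# iterator, with a trailing column for the constant term. The map string
# below is hypothetical.
#   isl_map_to_write_access_matrix("{ C[i, j] -> B[i, j + 1] }")
#   # -> array([[1, 0, 0],
#   #           [0, 1, 1]])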


def pad_access_matrix(access_matrix, max_depth):
@@ -53,121 +53,3 @@ def encode_data_type(data_type):
return [0, 1, 0]
elif data_type == "float64":
return [0, 0, 1]


def comps_to_vectors(annotations):
comp_vector_size = 718
max_depth = 5
dict_comp = {}
for comp in annotations["computations"]:
single_comp_vector = -np.ones(comp_vector_size)
# This means that this vector has data related to a computation and not an iterator
single_comp_vector[0] = 1
comp_dict = annotations["computations"][comp]
# This field represents the absolute order of execution of computations
single_comp_vector[1] = comp_dict["absolute_order"]
# a vector of one-hot encoding of possible 3 data-types
single_comp_vector[2:5] = encode_data_type(comp_dict["data_type"])
single_comp_vector[5] = +comp_dict["comp_is_reduction"]
# The write-to buffer id
single_comp_vector[6] = +comp_dict["write_buffer_id"]
# We add a vector of write access
write_matrix = isl_to_write_matrix(comp_dict["write_access_relation"])
padded_matrix = pad_access_matrix(write_matrix, max_depth).reshape(-1)
single_comp_vector[7 : 7 + padded_matrix.shape[0]] = padded_matrix
# We add vector of read access
for index, read_access_dict in enumerate(comp_dict["accesses"]):
read_access_matrix = pad_access_matrix(
np.array(read_access_dict["access_matrix"]), max_depth
).reshape(-1)
read_access_matrix = np.append(
read_access_matrix, +read_access_dict["access_is_reduction"]
)
read_access_matrix = np.append(
read_access_matrix, read_access_dict["buffer_id"] + 1
)
read_access_size = read_access_matrix.shape[0]
single_comp_vector[
49 + index * read_access_size : 49 + (index + 1) * read_access_size
] = read_access_matrix
dict_comp[comp] = single_comp_vector
return dict_comp


def build_graph(annotations):
it_vector_dict = iterators_to_vectors(annotations)
comp_vector_dict = comps_to_vectors(annotations)
it_index = {}
comp_index = {}
num_iterators = len(annotations["iterators"])
for i, it in enumerate(it_vector_dict):
it_index[it] = i
for i, comp in enumerate(comp_vector_dict):
comp_index[comp] = i

edge_index = []
node_feats = None

for it in annotations["iterators"]:
for child_it in annotations["iterators"][it]["child_iterators"]:
edge_index.append([it_index[it], it_index[child_it]])

for child_comp in annotations["iterators"][it]["computations_list"]:
edge_index.append([it_index[it], num_iterators + comp_index[child_comp]])
node_feats = np.stack(
[
*[arr for arr in it_vector_dict.values()],
*[arr for arr in comp_vector_dict.values()],
],
)
return node_feats, np.array(edge_index), it_index, comp_index


def apply_parallelization(iterator, node_feats, it_index):
index = it_index[iterator]
node_feats[index][-6] = 1


def apply_reversal(iterator, node_feats, it_index):
index = it_index[iterator]
node_feats[index][-5] = 1


def apply_unrolling(iterator, unrolling_factor, node_feats, it_index):
index = it_index[iterator]
node_feats[index][-4] = unrolling_factor


def apply_tiling(iterators, tile_sizes, node_feats, it_index):
for it, tile in zip(iterators, tile_sizes):
index = it_index[it]
node_feats[index][-3] = tile


def apply_skewing(iterators, skewing_factors, node_feats, it_index):
for it in iterators:
index = it_index[it]
node_feats[index][-2:] = skewing_factors


def apply_interchange(iterators, edge_index, it_index):
it1, it2 = it_index[iterators[0]], it_index[iterators[1]]
for edge in edge_index:
if edge[0] == it1:
edge[0] = it2
elif edge[0] == it2:
edge[0] = it1
if edge[1] == it1:
edge[1] = it2
elif edge[1] == it2:
edge[1] = it1


def focus_on_iterators(iterators, node_feats, it_index):
# We reset the value for all the nodes
node_feats[: len(it_index), -9:-8] = 0
# We focus on the branches' iterators
for it in iterators:
index = it_index[it]
node_feats[index][-9] = 1
return node_feats
4 changes: 2 additions & 2 deletions agent/policy_value_nn.py
@@ -134,10 +134,10 @@ def shared_layers(self, data):
def forward(self, data, actions_mask=None, action=None):
weights = self.shared_layers(data)
logits = self.π(weights)
if actions_mask != None:
if actions_mask is not None:
logits = logits - actions_mask * 1e8
probs = Categorical(logits=logits)
if action == None:
if action is None:
action = probs.sample()
value = self.v(weights)
return action, probs.log_prob(action), probs.entropy(), value
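
A note on the masking pattern in `forward` above: subtracting `actions_mask * 1e8` from the logits drives the probability of masked actions to effectively zero before sampling. A minimal sketch of that behavior, with illustrative values not taken from this PR:

```python
import torch
from torch.distributions import Categorical

# Illustrative logits for 4 actions; the mask marks actions 1 and 3 as invalid.
logits = torch.tensor([1.0, 2.0, 0.5, 3.0])
actions_mask = torch.tensor([0.0, 1.0, 0.0, 1.0])

# Same trick as in forward(): a large negative offset zeroes out the
# masked actions' probabilities under the softmax.
masked_logits = logits - actions_mask * 1e8
dist = Categorical(logits=masked_logits)

print(dist.probs)     # ~[0.62, 0.00, 0.38, 0.00]
print(dist.sample())  # only ever samples action 0 or 2
```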