From 7dbca12fcab491566d22e7e1d0c655500f2dc26c Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 28 Feb 2024 15:11:06 -0700 Subject: [PATCH 01/21] first throw at refactoring SamplingIterator --- src/gflownet/algo/trajectory_balance.py | 4 + src/gflownet/data/config.py | 6 + src/gflownet/data/data_source.py | 268 ++++++++++++++++++++++++ src/gflownet/data/sampling_iterator.py | 49 ++++- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/trainer.py | 66 +++--- src/gflownet/utils/misc.py | 21 ++ 7 files changed, 390 insertions(+), 26 deletions(-) create mode 100644 src/gflownet/data/data_source.py diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 75e5471f..308983ab 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -125,6 +125,7 @@ def __init__( # instead give "ABC...Z" as a single input, but grab the logits at every timestep. Only works if using something # like a transformer with causal self-attention. self.model_is_autoregressive = False + self.random_action_prob = [cfg.algo.train_random_action_prob, cfg.algo.valid_random_action_prob] self.graph_sampler = GraphSampler( ctx, @@ -140,6 +141,9 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + def set_is_eval(self, is_eval: bool): + self.is_eval = is_eval + def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float ): diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index fab5d036..5c5a9c84 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -16,9 +16,15 @@ class ReplayConfig: The number of samples to collect before starting to sample from the replay buffer hindsight_ratio : float The ratio of hindsight samples within a batch + batch_size : Optional[int] + The batch size for sampling from the replay buffer, defaults to the online batch size + replaces_online_data : bool + Whether to replace online data with samples from the replay buffer """ use: bool = False capacity: Optional[int] = None warmup: Optional[int] = None hindsight_ratio: float = 0 + batch_size: Optional[int] = None + replaces_online_data: bool = True diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py new file mode 100644 index 00000000..e6354e40 --- /dev/null +++ b/src/gflownet/data/data_source.py @@ -0,0 +1,268 @@ +import numpy as np +import torch +from gflownet.data.replay_buffer import ReplayBuffer +from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple, Generator +from torch.utils.data import Dataset, IterableDataset + +from gflownet.config import Config +from gflownet.utils.misc import get_worker_rng +from gflownet.envs.graph_building_env import GraphBuildingEnvContext +#from gflownet.trainer import GFNAlgorithm, GFNTask + + +def cycle_call(it): + while True: + for i in it(): + yield i + + +class DataSource(IterableDataset): + def __init__( + self, + cfg: Config, + ctx: GraphBuildingEnvContext, + algo, #: GFNAlgorithm, + task, #: GFNTask, # TODO: this will cause a circular import + dev: torch.device, + replay_buffer: Optional[ReplayBuffer] = None, + is_algo_eval: bool = False, + start_at_step: int = 0, + ): + """A DataSource mixes multiple iterators into one. These are created with do_* methods.""" + self.iterators: List[Generator] = [] + self.cfg = cfg + self.ctx = ctx + self.algo = algo + self.task = task + self.dev = dev + self.replay_buffer = replay_buffer + self.is_algo_eval = is_algo_eval + + self.global_step_count = torch.zeros(1, dtype=torch.int64) + start_at_step + self.global_step_count.share_memory_() + self.global_step_count_lock = torch.multiprocessing.Lock() + self.current_iter = start_at_step + self.sampling_hooks: List[Callable] = [] + self.active = True + + def add_sampling_hook(self, hook: Callable): + """Add a hook that is called when sampling new trajectories. + + The hook should take a list of trajectories as input. + The hook will not be called on trajectories that are sampled from the replay buffer or dataset. + """ + self.sampling_hooks.append(hook) + return self + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + self.rng = get_worker_rng() + its = [i() for i in self.iterators] + self.algo.set_is_eval(self.is_algo_eval) + while True: + with self.global_step_count_lock: + self.current_iter = self.global_step_count.item() + self.global_step_count += 1 + iterator_outputs = [next(i, None) for i in its] + if any(i is None for i in iterator_outputs): + if not all(i is None for i in iterator_outputs): + raise ValueError("Some iterators are done, but not all. You may be mixing incompatible iterators.") + break + traj_lists, batch_infos = zip(*iterator_outputs) + trajs = sum(traj_lists, []) + # Merge all the dicts into one + batch_info = {} + for d in batch_infos: + batch_info.update(d) + yield self.create_batch(trajs, batch_info) + + def do_sample_model(self, model, num_samples, keep_samples_in_batch=True): + if not keep_samples_in_batch: + assert self.replay_buffer is not None, "Throwing away samples without a replay buffer" + + def iterator(): + while self.active: + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_cond_info(num_samples, t) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + batch_info = self.call_sampling_hooks(trajs) + yield (trajs, batch_info) if keep_samples_in_batch else ([], {}) + + self.iterators.append(iterator) + return self + + def do_sample_replay(self, num_samples): + def iterator(): + while self.active: + trajs = self.replay_buffer.sample(num_samples) + self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 + yield trajs, {} + show_type(iterator) + self.iterators.append(iterator) + return self + + def do_dataset_in_order(self, data, num_samples, backwards_model): + def iterator(): + for idcs in self.iterate_indices(num_samples): + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(num_samples, t) + objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.set_traj_props(trajs, props) + self.compute_log_rewards(trajs) + yield trajs, {} + + self.iterators.append(iterator) + return self + + def do_conditionals_dataset_in_order(self, data, num_samples, model): + def iterator(): + for idcs in self.iterate_indices(len(data), num_samples): + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = torch.stack([data[i] for i in idcs]) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + batch_info = self.call_sampling_hooks(trajs) + yield trajs, batch_info + + self.iterators.append(iterator) + return self + + def do_sample_dataset(self, data, num_samples, backwards_model): + def iterator(): + while self.active: + idcs = self.sample_idcs(len(data), num_samples) + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(num_samples, t) + objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.set_traj_props(trajs, props) + self.compute_log_rewards(trajs) + yield trajs, {} + + self.iterators.append(iterator) + return self + + def call_sampling_hooks(self, trajs): + batch_info = {} + # TODO: just pass trajs to the hooks and deprecate passing all those arguments + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + # convert cond_info back to a dict + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs["cond_info"][0]} + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + for hook in self.sampling_hooks: + batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + + def create_batch(self, trajs, batch_info): + ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + batch = self.algo.construct_batch(trajs, ci, log_rewards) + batch.num_online = sum(t["is_online"] for t in trajs) + batch.num_offline = len(trajs) - batch.num_online + batch.extra_info = batch_info + batch.preferences = torch.stack([t["preference"] for t in trajs]) + batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + + if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute + log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] + batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) + batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) + # TODO: find code that depends on batch.flat_rewards and deprecate it + return batch + + def compute_properties(self, trajs, mark_as_online=False): + """Sets trajs' flat_rewards and is_valid keys by querying the task.""" + # TODO: refactor flat_rewards into properties + valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() + # fetch the valid trajectories endpoints + objs = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] + # ask the task to compute their reward + # TODO: it's really weird that the task is responsible for this and returns a flat_rewards + # tensor whose first dimension is possibly not the same as the output??? + flat_rewards, m_is_valid = self.task.compute_flat_rewards(objs) + assert flat_rewards.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + # The task may decide some of the objs are invalid, we have to again filter those + valid_idcs = valid_idcs[m_is_valid] + all_fr = torch.zeros((len(trajs), flat_rewards.shape[1])) + all_fr[valid_idcs] = flat_rewards + for i in range(len(trajs)): + trajs[i]["flat_rewards"] = all_fr[i] + trajs[i]["is_online"] = mark_as_online + # Override the is_valid key in case the task made some objs invalid + for i in valid_idcs: + trajs[i]["is_valid"] = True + + def compute_log_rewards(self, trajs): + """Sets trajs' log_reward key by querying the task.""" + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} + log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + for i in range(len(trajs)): + trajs[i]["log_reward"] = log_rewards[i] if trajs[i]["is_valid"] else self.cfg.algo.illegal_action_logreward + + def send_to_replay(self, trajs): + if self.replay_buffer is not None: + for t in trajs: + self.replay_buffer.push(t, t["log_rewards"], t["flat_rewards"], t["cond_info"], t["is_valid"]) + + def set_traj_cond_info(self, trajs, cond_info): + for i in range(len(trajs)): + trajs[i]["cond_info"] = {k: cond_info[k][i] for k in cond_info} + + def set_traj_props(self, trajs, props): + for i in range(len(trajs)): + trajs[i]["flat_rewards"] = props[i] # TODO: refactor + + def relabel_in_hindsight(self, trajs): + if self.cfg.replay.hindsight_ratio == 0: + return + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(len(trajs))[: int(len(trajs) * self.cfg.replay.hindsight_ratio)] + log_rewards = torch.stack([t["log_reward"] for t in trajs]) + flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + # TODO: This seems wrong, since we haven't recomputed is_valid + # log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + def sample_idcs(self, n, num_samples): + return self.rng.choice(n, num_samples, replace=False) + + def iterate_indices(self, n, num_samples): + worker_info = torch.utils.data.get_worker_info() + if n == 0: + # Should we be raising an error here? warning? + yield np.arange(0, 0) + return + + if worker_info is None: # no multi-processing + start, end, wid = 0, n, -1 + else: # split the data into chunks (per-worker) + nw = worker_info.num_workers + wid = worker_info.id + start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) + + if end - start <= num_samples: + yield np.arange(start, end) + return + for i in range(start, end - num_samples, num_samples): + yield np.arange(i, i + num_samples) + if i + num_samples < end: + yield np.arange(i + num_samples, end) \ No newline at end of file diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 9795467e..cf15f66c 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -198,7 +198,7 @@ def __iter__(self): num_online = num_offline num_offline = 0 cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) + steer_info=torch.stack([self.data[i] for i in idcs]) # This is sus, what's going on here? ) trajs, flat_rewards = [], [] @@ -400,6 +400,53 @@ def log_generated(self, trajs, rewards, flat_rewards, cond_info): self.log.insert_many(data, data_labels) +class SQLiteLogHook: + def __init__(self, log_dir, ctx) -> None: + self.log = None # Only initialized in __call__, which will occur inside the worker + self.log_dir = log_dir + self.ctx = ctx + self.data_labels = None + + def __call__(self, trajs, rewards, flat_rewards, cond_info): + if self.log is None: + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log = SQLiteLog() + self.log.connect(self.log_path) + + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + if self.data_labels is None: + self.data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, self.data_labels) + + class SQLiteLog: def __init__(self, timeout=300): """Creates a log instance, but does not connect it to any db.""" diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 91d65818..fcce6ca6 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -206,7 +206,7 @@ def main(): "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, "num_training_steps": 10_000, - "num_workers": 8, + "num_workers": 0, "opt": { "lr_decay": 20000, }, diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index e60d742e..afeceafe 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -16,8 +16,9 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SamplingIterator +from gflownet.data.sampling_iterator import SamplingIterator, SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger @@ -35,9 +36,11 @@ class GFNAlgorithm: updates: int = 0 + global_cfg: Config + is_eval: bool = False def step(self): - self.updates += 1 + self.updates += 1 # This isn't used anywhere? def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 @@ -62,6 +65,13 @@ def compute_batch_losses( """ raise NotImplementedError() + def get_random_action_prob(self, it: int): + if self.is_eval: + return self.global_cfg.algo.valid_random_action_prob + if it < self.global_cfg.algo.train_det_after or self.global_cfg.algo.train_det_after is None: + return self.global_cfg.algo.train_random_action_prob + return 0 + class GFNTask: def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: @@ -188,14 +198,14 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - wapper = mp_object_wrapper( + wrapper = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) - self.to_terminate.append(wapper.terminate) - return wapper.placeholder, torch.device("cpu") + self.to_terminate.append(wrapper.terminate) + return wrapper.placeholder, torch.device("cpu") else: return obj, self.device @@ -203,28 +213,36 @@ def build_callbacks(self): return {} def build_training_data_loader(self) -> DataLoader: + # Since the model may be used by a worker in a different process, we need to wrap it. + # The device `dev` returned here is the device that the worker will use to interact with the model; + # normally, if the main process has the model on 'cuda', this will simply be 'cpu' (since workers + # don't have CUDA access). + # See implementation_nodes.md for more details. model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=replay_buffer, - ratio=self.cfg.algo.offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), - random_action_prob=self.cfg.algo.train_random_action_prob, - det_after=self.cfg.algo.train_det_after, - hindsight_ratio=self.cfg.replay.hindsight_ratio, - ) + + n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.offline_ratio)) + n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size + n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, dev, replay_buffer=replay_buffer) + if n_from_dataset: + src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) + if n_drawn: + # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just + # send them to the replay and train only on replay samples. + keep_samples_in_batch = not self.cfg.replay.use or not self.cfg.replay.replaces_online_data + src.do_sample_model(model, n_drawn, keep_samples_in_batch) + if n_replayed and replay_buffer is not None: + src.do_sample_replay(n_replayed) + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "train"), self.ctx)) for hook in self.sampling_hooks: - iterator.add_log_hook(hook) + src.add_sampling_hook(hook) + # TODO: We could just have a build_training_data_source method that returns a DataSource + # All the other build_* methods do the same DataLoader setup return torch.utils.data.DataLoader( - iterator, + src, batch_size=None, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers > 0, @@ -296,7 +314,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: if not torch.isfinite(loss): raise ValueError("loss is not finite") step_info = self.step(loss) - self.algo.step() + self.algo.step() # This also isn't used anywhere? if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index d8b350b2..f65a83d5 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -1,6 +1,9 @@ import logging import sys +import numpy as np +import torch + def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) @@ -21,3 +24,21 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHand logger.addHandler(handler) return logger + + +_worker_rngs = {} +_worker_rng_seed = [142857] + + +def get_worker_rng(): + worker_info = torch.utils.data.get_worker_info() + wid = worker_info.id if worker_info is not None else 0 + if wid not in _worker_rngs: + _worker_rngs[wid] = np.random.RandomState(_worker_rng_seed[0] + wid) + return _worker_rngs[wid] + + +def set_worker_rng_seed(seed): + _worker_rng_seed[0] = seed + for wid in _worker_rngs: + _worker_rngs[wid].seed(seed + wid) From dfba1ca478ee8a7b29cdb54753f25bc8ca2c7000 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 09:59:00 -0700 Subject: [PATCH 02/21] changed all iterators to DataSource --- src/gflownet/algo/soft_q_learning.py | 3 +- src/gflownet/algo/trajectory_balance.py | 9 +- src/gflownet/config.py | 1 + src/gflownet/data/data_source.py | 92 ++++++++++++++------ src/gflownet/data/sampling_iterator.py | 1 + src/gflownet/envs/frag_mol_env.py | 1 - src/gflownet/envs/mol_building_env.py | 1 - src/gflownet/online_trainer.py | 1 + src/gflownet/tasks/seh_frag.py | 11 ++- src/gflownet/trainer.py | 109 ++++++++++-------------- src/gflownet/utils/misc.py | 12 +++ 11 files changed, 142 insertions(+), 99 deletions(-) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index 1e3f1146..378d0b7e 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -8,6 +8,7 @@ from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory +from gflownet.utils.misc import get_worker_device class SoftQLearning: @@ -75,7 +76,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 308983ab..fcd171b4 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -22,6 +22,7 @@ generate_forward_trajectory, ) from gflownet.trainer import GFNAlgorithm +from gflownet.utils.misc import get_worker_device def shift_right(x: torch.Tensor, z=0): @@ -139,7 +140,7 @@ def __init__( ) if self.cfg.variant == TBVariant.SubTB1: self._subtb_max_len = self.global_cfg.algo.max_len + 2 - self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + self._init_subtb(get_worker_device()) def set_is_eval(self, is_eval: bool): self.is_eval = is_eval @@ -171,7 +172,7 @@ def create_training_data_from_own_samples( - loss: predicted loss (if bootstrapping) - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) logZ_pred = model.logZ(cond_info) @@ -206,7 +207,7 @@ def create_training_data_from_graphs( """ if self.cfg.do_sample_p_b: assert model is not None and cond_info is not None and random_action_prob is not None - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob @@ -217,7 +218,7 @@ def create_training_data_from_graphs( self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) for gp, _ in traj["traj"][1:] ] + [1] - traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) + traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(get_worker_device()) traj["result"] = traj["traj"][-1][0] if self.cfg.do_parameterize_p_b: traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 782b4ff4..73ed6f15 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -94,6 +94,7 @@ class Config: print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None + num_validation_gen_steps: Optional[int] = None num_training_steps: int = 10_000 num_workers: int = 0 hostname: Optional[str] = None diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index e6354e40..5cc79b61 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -1,13 +1,16 @@ +import warnings +from typing import Callable, Generator, List, Optional + import numpy as np import torch -from gflownet.data.replay_buffer import ReplayBuffer -from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple, Generator -from torch.utils.data import Dataset, IterableDataset +from torch.utils.data import IterableDataset from gflownet.config import Config -from gflownet.utils.misc import get_worker_rng +from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext -#from gflownet.trainer import GFNAlgorithm, GFNTask +from gflownet.utils.misc import get_worker_rng + +# from gflownet.trainer import GFNAlgorithm, GFNTask def cycle_call(it): @@ -21,9 +24,8 @@ def __init__( self, cfg: Config, ctx: GraphBuildingEnvContext, - algo, #: GFNAlgorithm, - task, #: GFNTask, # TODO: this will cause a circular import - dev: torch.device, + algo, #: GFNAlgorithm, + task, #: GFNTask, # TODO: this will cause a circular import replay_buffer: Optional[ReplayBuffer] = None, is_algo_eval: bool = False, start_at_step: int = 0, @@ -34,16 +36,15 @@ def __init__( self.ctx = ctx self.algo = algo self.task = task - self.dev = dev self.replay_buffer = replay_buffer self.is_algo_eval = is_algo_eval + self.sampling_hooks: List[Callable] = [] + self.active = True self.global_step_count = torch.zeros(1, dtype=torch.int64) + start_at_step self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step - self.sampling_hooks: List[Callable] = [] - self.active = True def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -67,8 +68,10 @@ def __iter__(self): iterator_outputs = [next(i, None) for i in its] if any(i is None for i in iterator_outputs): if not all(i is None for i in iterator_outputs): - raise ValueError("Some iterators are done, but not all. You may be mixing incompatible iterators.") - break + warnings.warn("Some iterators are done, but not all. You may be mixing incompatible iterators.") + iterator_outputs = [i for i in iterator_outputs if i is not None] + else: + break traj_lists, batch_infos = zip(*iterator_outputs) trajs = sum(traj_lists, []) # Merge all the dicts into one @@ -85,8 +88,9 @@ def iterator(): while self.active: t = self.current_iter p = self.algo.get_random_action_prob(t) - cond_info = self.task.sample_cond_info(num_samples, t) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + cond_info = self.task.sample_conditional_information(num_samples, t) + # TODO: in the cond info refactor, pass the whole thing instead of just the encoding + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) @@ -97,13 +101,43 @@ def iterator(): self.iterators.append(iterator) return self + def do_sample_model_n_times(self, model, num_samples_per_batch, num_total): + total = torch.zeros(1, dtype=torch.int64) + total.share_memory_() + total_lock = torch.multiprocessing.Lock() + total_barrier = torch.multiprocessing.Barrier(max(1, self.cfg.num_workers)) + + def iterator(): + while self.active: + with total_lock: + n_so_far = total.item() + n_this_time = min(num_total - n_so_far, num_samples_per_batch) + total[:] += n_this_time + if n_this_time == 0: + break + t = self.current_iter + p = self.algo.get_random_action_prob(t) + cond_info = self.task.sample_conditional_information(n_this_time, t) + # TODO: in the cond info refactor, pass the whole thing instead of just the encoding + trajs = self.algo.create_training_data_from_own_samples(model, n_this_time, cond_info["encoding"], p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs + self.compute_properties(trajs, mark_as_online=True) + self.compute_log_rewards(trajs) + batch_info = self.call_sampling_hooks(trajs) + yield trajs, batch_info + total_barrier.wait() # Wait for all workers to finish before resetting the counter + total[:] = 0 + + self.iterators.append(iterator) + return self + def do_sample_replay(self, num_samples): def iterator(): while self.active: trajs = self.replay_buffer.sample(num_samples) self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} - show_type(iterator) + self.iterators.append(iterator) return self @@ -114,7 +148,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -129,7 +163,7 @@ def iterator(): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = torch.stack([data[i] for i in idcs]) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info, p) + trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) self.send_to_replay(trajs) # This is a no-op if there is no replay buffer @@ -147,7 +181,7 @@ def iterator(): p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) objs, props = map(list, zip(*[data[i] for i in idcs])) if len(idcs) else ([], []) - trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info, p) + trajs = self.algo.create_training_data_from_graphs(objs, backwards_model, cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.set_traj_props(trajs, props) self.compute_log_rewards(trajs) @@ -161,10 +195,11 @@ def call_sampling_hooks(self, trajs): # TODO: just pass trajs to the hooks and deprecate passing all those arguments flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) # convert cond_info back to a dict - cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs["cond_info"][0]} + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = torch.stack([t["log_reward"] for t in trajs]) for hook in self.sampling_hooks: batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + return batch_info def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) @@ -173,8 +208,10 @@ def create_batch(self, trajs, batch_info): batch.num_online = sum(t["is_online"] for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info - batch.preferences = torch.stack([t["preference"] for t in trajs]) - batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) + if "preferences" in trajs[0]: + batch.preferences = torch.stack([t["preferences"] for t in trajs]) + if "focus_dir" in trajs[0]: + batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] @@ -230,12 +267,13 @@ def relabel_in_hindsight(self, trajs): if self.cfg.replay.hindsight_ratio == 0: return assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" # samples indexes of trajectories without repeats hindsight_idxs = torch.randperm(len(trajs))[: int(len(trajs) * self.cfg.replay.hindsight_ratio)] log_rewards = torch.stack([t["log_reward"] for t in trajs]) flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( cond_info, log_rewards, flat_rewards, hindsight_idxs ) @@ -251,18 +289,18 @@ def iterate_indices(self, n, num_samples): # Should we be raising an error here? warning? yield np.arange(0, 0) return - + if worker_info is None: # no multi-processing start, end, wid = 0, n, -1 else: # split the data into chunks (per-worker) nw = worker_info.num_workers wid = worker_info.id start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - + if end - start <= num_samples: yield np.arange(start, end) return for i in range(start, end - num_samples, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: - yield np.arange(i + num_samples, end) \ No newline at end of file + yield np.arange(i + num_samples, end) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index cf15f66c..04ff0ebe 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -445,6 +445,7 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): ) self.log.insert_many(data, self.data_labels) + return {} class SQLiteLog: diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index bab9506b..daf9f99f 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -87,7 +87,6 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu GraphActionType.RemoveNode, GraphActionType.RemoveEdgeAttr, ] - self.device = torch.device("cpu") self.n_counter = NCounter() self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms()) diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 20c05586..5e43dd0b 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -157,7 +157,6 @@ def __init__( GraphActionType.RemoveEdge, GraphActionType.RemoveEdgeAttr, ] - self.device = torch.device("cpu") def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index edda9c79..2e59304f 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -6,6 +6,7 @@ import torch from omegaconf import OmegaConf from torch import Tensor +from torch.utils.data import DataLoader from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 2d56f213..54adcc7c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -158,6 +158,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 cfg.algo.valid_offline_ratio = 0 + cfg.num_validation_gen_steps = 10 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-3 @@ -199,13 +200,17 @@ def setup(self): def main(): """Example of how this model can be run.""" + import datetime + config = init_empty(Config()) config.print_every = 1 - config.log_dir = "./logs/debug_run_seh_frag_pb" + config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True - config.num_training_steps = 10_000 - config.num_workers = 0 + config.num_training_steps = 1_00 + config.validate_every = 20 + config.num_final_gen_steps = 10 + config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 config.algo.offline_ratio = 0.0 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d903821b..17f38aef 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -19,10 +19,10 @@ from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SamplingIterator, SQLiteLogHook +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch -from gflownet.utils.misc import create_logger +from gflownet.utils.misc import create_logger, set_main_process_device from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from .config import Config @@ -69,7 +69,7 @@ def compute_batch_losses( def get_random_action_prob(self, it: int): if self.is_eval: return self.global_cfg.algo.valid_random_action_prob - if it < self.global_cfg.algo.train_det_after or self.global_cfg.algo.train_det_after is None: + if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: return self.global_cfg.algo.train_random_action_prob return 0 @@ -150,9 +150,10 @@ def __init__(self, config: Config, print_config=True): assert isinstance(self.default_cfg, Config) and isinstance( config, Config ) # make sure the config is a Config object, and not the Config class itself - self.cfg = OmegaConf.merge(self.default_cfg, config) + self.cfg: Config = OmegaConf.merge(self.default_cfg, config) self.device = torch.device(self.cfg.device) + set_main_process_device(self.device) # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data @@ -223,6 +224,15 @@ def _wrap_for_mp(self, obj, send_to_device=False): def build_callbacks(self): return {} + def _make_data_loader(self, src): + return torch.utils.data.DataLoader( + src, + batch_size=None, + num_workers=self.cfg.num_workers, + persistent_workers=self.cfg.num_workers > 0, + prefetch_factor=1 if self.cfg.num_workers else None, + ) + def build_training_data_loader(self) -> DataLoader: # Since the model may be used by a worker in a different process, we need to wrap it. # The device `dev` returned here is the device that the worker will use to interact with the model; @@ -236,7 +246,7 @@ def build_training_data_loader(self) -> DataLoader: n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size n_from_dataset = self.cfg.algo.global_batch_size - n_drawn - src = DataSource(self.cfg, self.ctx, self.algo, self.task, dev, replay_buffer=replay_buffer) + src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: @@ -250,70 +260,45 @@ def build_training_data_loader(self) -> DataLoader: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "train"), self.ctx)) for hook in self.sampling_hooks: src.add_sampling_hook(hook) - # TODO: We could just have a build_training_data_source method that returns a DataSource - # All the other build_* methods do the same DataLoader setup - return torch.utils.data.DataLoader( - src, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + return self._make_data_loader(src) def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - iterator = SamplingIterator( - self.test_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - ratio=self.cfg.algo.valid_offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "valid"), - sample_cond_info=self.cfg.cond.valid_sample_cond_info, - stream=False, - random_action_prob=self.cfg.algo.valid_random_action_prob, - ) + # TODO: we're changing the default, make sure anything that is using test data is adjusted + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.valid_offline_ratio)) + n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + if n_from_dataset: + src.do_dataset_in_order(self.test_data, n_from_dataset, backwards_model=model) + if n_drawn: + assert self.cfg.num_validation_gen_steps is not None + # TODO: might be better to change total steps to total trajectories drawn + src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_validation_gen_steps * n_drawn) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) for hook in self.valid_sampling_hooks: - iterator.add_log_hook(hook) - return torch.utils.data.DataLoader( - iterator, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + src.add_sampling_hook(hook) + return self._make_data_loader(src) def build_final_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=None, - ratio=0.0, - log_dir=os.path.join(self.cfg.log_dir, "final"), - random_action_prob=0.0, - hindsight_ratio=0.0, - init_train_iter=self.cfg.num_training_steps, - ) + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + # TODO: we're changing the default, make sure anything that is using test data is adjusted + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = int(self.cfg.algo.global_batch_size) + + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + assert self.cfg.num_final_gen_steps is not None + # TODO: might be better to change total steps to total trajectories drawn + src.do_sample_model_n_times(model, n_drawn, num_total=self.cfg.num_final_gen_steps * n_drawn) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "final"), self.ctx)) for hook in self.sampling_hooks: - iterator.add_log_hook(hook) - return torch.utils.data.DataLoader( - iterator, - batch_size=None, - num_workers=self.cfg.num_workers, - persistent_workers=self.cfg.num_workers > 0, - prefetch_factor=1 if self.cfg.num_workers else None, - ) + src.add_sampling_hook(hook) + return self._make_data_loader(src) def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: tick = time.time() diff --git a/src/gflownet/utils/misc.py b/src/gflownet/utils/misc.py index f65a83d5..7ec5bdba 100644 --- a/src/gflownet/utils/misc.py +++ b/src/gflownet/utils/misc.py @@ -8,6 +8,8 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): logger = logging.getLogger(name) logger.setLevel(loglevel) + while len([logger.removeHandler(i) for i in logger.handlers]): + pass # Remove all handlers (only useful when debugging) formatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - {} - %(message)s".format(name), datefmt="%d/%m/%Y %H:%M:%S", @@ -28,6 +30,7 @@ def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHand _worker_rngs = {} _worker_rng_seed = [142857] +_main_process_device = [torch.device("cpu")] def get_worker_rng(): @@ -42,3 +45,12 @@ def set_worker_rng_seed(seed): _worker_rng_seed[0] = seed for wid in _worker_rngs: _worker_rngs[wid].seed(seed + wid) + + +def set_main_process_device(device): + _main_process_device[0] = device + + +def get_worker_device(): + worker_info = torch.utils.data.get_worker_info() + return _main_process_device[0] if worker_info is None else torch.device("cpu") From e5239fb48726bd27e9fc4cfca435e5e98510a6d2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 15:10:00 -0700 Subject: [PATCH 03/21] lots of little fixes, tested all tasks, better device management --- src/gflownet/algo/advantage_actor_critic.py | 4 +-- src/gflownet/algo/envelope_q_learning.py | 3 +- src/gflownet/data/data_source.py | 23 +++++++++++---- src/gflownet/online_trainer.py | 7 +++++ src/gflownet/tasks/qm9_moo.py | 31 ++++++++++++++++----- src/gflownet/tasks/seh_frag_moo.py | 23 +++++++++++++-- src/gflownet/trainer.py | 2 +- 7 files changed, 73 insertions(+), 20 deletions(-) diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 001e19d0..c6547b3d 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -6,7 +6,7 @@ from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory - +from gflownet.utils.misc import get_worker_device from .graph_sampling import GraphSampler @@ -79,7 +79,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 4d694ae2..7adfd68c 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -15,6 +15,7 @@ ) from gflownet.models.graph_transformer import GraphTransformer, mlp from gflownet.trainer import GFNTask +from gflownet.utils.misc import get_worker_device from .graph_sampling import GraphSampler @@ -233,7 +234,7 @@ def create_training_data_from_own_samples( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ - dev = self.ctx.device + dev = get_worker_device() cond_info = cond_info.to(dev) data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) return data diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 5cc79b61..4df5dcd0 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -134,7 +134,7 @@ def iterator(): def do_sample_replay(self, num_samples): def iterator(): while self.active: - trajs = self.replay_buffer.sample(num_samples) + trajs, *_ = self.replay_buffer.sample(num_samples) self.relabel_in_hindsight(trajs) # This is a no-op if the hindsight ratio is 0 yield trajs, {} @@ -143,7 +143,7 @@ def iterator(): def do_dataset_in_order(self, data, num_samples, backwards_model): def iterator(): - for idcs in self.iterate_indices(num_samples): + for idcs in self.iterate_indices(len(data), num_samples): t = self.current_iter p = self.algo.get_random_action_prob(t) cond_info = self.task.sample_conditional_information(num_samples, t) @@ -162,11 +162,18 @@ def iterator(): for idcs in self.iterate_indices(len(data), num_samples): t = self.current_iter p = self.algo.get_random_action_prob(t) - cond_info = torch.stack([data[i] for i in idcs]) + # TODO: when we refactor cond_info, data[i] will probably be a dict? (or CondInfo objects) + # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to + # it and the state of the program (e.g. validation mode) + cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + # If we're using a dataset of preferences, the user/hooks may want to know the id of the preference + for i, j in zip(trajs, idcs): + i["data_idx"] = j batch_info = self.call_sampling_hooks(trajs) yield trajs, batch_info @@ -197,15 +204,16 @@ def call_sampling_hooks(self, trajs): # convert cond_info back to a dict cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = torch.stack([t["log_reward"] for t in trajs]) + rewards = torch.exp(log_rewards / (cond_info.get("beta", 1))) for hook in self.sampling_hooks: - batch_info.update(hook(trajs, log_rewards, flat_rewards, cond_info)) + batch_info.update(hook(trajs, rewards, flat_rewards, cond_info)) return batch_info def create_batch(self, trajs, batch_info): ci = torch.stack([t["cond_info"]["encoding"] for t in trajs]) log_rewards = torch.stack([t["log_reward"] for t in trajs]) batch = self.algo.construct_batch(trajs, ci, log_rewards) - batch.num_online = sum(t["is_online"] for t in trajs) + batch.num_online = sum(t.get("is_online", 0) for t in trajs) batch.num_offline = len(trajs) - batch.num_online batch.extra_info = batch_info if "preferences" in trajs[0]: @@ -247,8 +255,11 @@ def compute_log_rewards(self, trajs): flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + min_r = torch.as_tensor(self.cfg.algo.illegal_action_logreward).float() for i in range(len(trajs)): - trajs[i]["log_reward"] = log_rewards[i] if trajs[i]["is_valid"] else self.cfg.algo.illegal_action_logreward + trajs[i]["log_reward"] = ( + log_rewards[i] if trajs[i].get("is_valid", True) else min_r + ) def send_to_replay(self, trajs): if self.replay_buffer is not None: diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 2e59304f..a73802d2 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -73,6 +73,8 @@ def setup(self): super().setup() self.offline_ratio = 0 self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + self.sampling_hooks.append(AvgRewardHook()) + self.valid_sampling_hooks.append(AvgRewardHook()) # Separate Z parameters from non-Z to allow for LR decay on the former if hasattr(self.model, "logZ"): @@ -130,3 +132,8 @@ def step(self, loss: Tensor): for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) return {"grad_norm": g0, "grad_norm_clip": g1} + + +class AvgRewardHook: + def __call__(self, trajs, rewards, flat_rewards, extra_info): + return {"sampled_reward_avg": rewards.mean().item()} diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index b1dab870..d69e67f9 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -7,13 +7,15 @@ import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader import gflownet.models.mxmnet as mxmnet from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset +from gflownet.data.data_source import DataSource +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks @@ -45,6 +47,7 @@ def __init__( self.cfg = cfg mcfg = self.cfg.task.qm9_moo self.objectives = cfg.task.qm9_moo.objectives + cfg.cond.moo.num_objectives = len(self.objectives) self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) @@ -224,14 +227,14 @@ def setup_model(self): num_emb=self.cfg.model.num_emb, num_layers=self.cfg.model.num_layers, num_heads=self.cfg.model.graph_transformer.num_heads, - num_objectives=len(self.cfg.task.seh_moo.objectives), + num_objectives=len(self.cfg.task.qm9_moo.objectives), ) else: super().setup_model() def setup(self): super().setup() - if self.cfg.task.seh_moo.online_pareto_front: + if self.cfg.task.qm9_moo.online_pareto_front: self.sampling_hooks.append( MultiObjectiveStatsHook( 256, @@ -245,7 +248,7 @@ def setup(self): self.to_terminate.append(self.sampling_hooks[-1].terminate) # instantiate preference and focus conditioning vectors for validation - n_obj = len(self.cfg.task.seh_moo.objectives) + n_obj = len(self.cfg.task.qm9_moo.objectives) cond_cfg = self.cfg.cond # making sure hyperparameters for preferences and focus regions are consistent @@ -263,7 +266,7 @@ def setup(self): if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1: n_valid = len(cond_cfg.focus_region.focus_type) else: - n_valid = self.cfg.task.seh_moo.n_valid + n_valid = self.cfg.task.qm9_moo.n_valid # preference vectors if cond_cfg.weighted_prefs.preference_type is None: @@ -298,8 +301,8 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) - self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) + self._top_k_hook = TopKHook(10, self.cfg.task.qm9_moo.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.qm9_moo.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -324,6 +327,20 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} + def build_validation_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + + n_from_dataset = self.cfg.algo.global_batch_size + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) + for hook in self.valid_sampling_hooks: + src.add_sampling_hook(hook) + + return self._make_data_loader(src) + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: self.task.focus_cond.step_focus_model(batch, train_it) diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 1f8787d3..9ee69408 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -8,12 +8,14 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset +from torch.utils.data import Dataset, DataLoader from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext +from gflownet.data.data_source import DataSource +from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.trainer import FlatRewards, RewardScalar @@ -70,6 +72,7 @@ def __init__( self.cfg = cfg mcfg = self.cfg.task.seh_moo self.objectives = cfg.task.seh_moo.objectives + cfg.cond.moo.num_objectives = len(self.objectives) # This value is used by the focus_cond and pref_cond self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) @@ -350,6 +353,20 @@ def on_validation_end(self, metrics: Dict[str, Any]): return callback_dict + def build_validation_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.model, send_to_device=True) + + n_from_dataset = self.cfg.algo.global_batch_size + src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + + if self.cfg.log_dir: + src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) + for hook in self.valid_sampling_hooks: + src.add_sampling_hook(hook) + + return self._make_data_loader(src) + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: if self.task.focus_cond is not None: self.task.focus_cond.step_focus_model(batch, train_it) @@ -363,7 +380,7 @@ def _save_state(self, it): class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): - self.cond_info_vectors = cond_info_vectors + self.cond_info_vectors = torch.as_tensor(cond_info_vectors).float() self.repeat = repeat def __len__(self): @@ -371,7 +388,7 @@ def __len__(self): def __getitem__(self, idx): assert 0 <= idx < len(self) - return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) + return self.cond_info_vectors[int(idx // self.repeat)] def main(): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 17f38aef..a9409562 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -248,7 +248,7 @@ def build_training_data_loader(self) -> DataLoader: src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: - src.do_dataset_in_order(self.training_data, n_from_dataset, backwards_model=model) + src.do_sample_dataset(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just # send them to the replay and train only on replay samples. From 43dfc2b0c010cca57edbf10f5d92d229ee2286ff Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 29 Feb 2024 17:51:51 -0700 Subject: [PATCH 04/21] style --- src/gflownet/algo/advantage_actor_critic.py | 1 + src/gflownet/data/data_source.py | 4 +--- src/gflownet/online_trainer.py | 1 - src/gflownet/tasks/qm9_moo.py | 4 ++-- src/gflownet/tasks/seh_frag_moo.py | 4 ++-- 5 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index c6547b3d..40c58010 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -7,6 +7,7 @@ from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device + from .graph_sampling import GraphSampler diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 4df5dcd0..b68bfc8c 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -257,9 +257,7 @@ def compute_log_rewards(self, trajs): log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) min_r = torch.as_tensor(self.cfg.algo.illegal_action_logreward).float() for i in range(len(trajs)): - trajs[i]["log_reward"] = ( - log_rewards[i] if trajs[i].get("is_valid", True) else min_r - ) + trajs[i]["log_reward"] = log_rewards[i] if trajs[i].get("is_valid", True) else min_r def send_to_replay(self, trajs): if self.replay_buffer is not None: diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index a73802d2..9d30d457 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -6,7 +6,6 @@ import torch from omegaconf import OmegaConf from torch import Tensor -from torch.utils.data import DataLoader from gflownet.algo.advantage_actor_critic import A2C from gflownet.algo.flow_matching import FlowMatching diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index d69e67f9..45c3576b 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -7,14 +7,14 @@ import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset import gflownet.models.mxmnet as mxmnet from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config -from gflownet.data.qm9 import QM9Dataset from gflownet.data.data_source import DataSource +from gflownet.data.qm9 import QM9Dataset from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 9ee69408..60f8092c 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -8,14 +8,14 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty -from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.data.data_source import DataSource from gflownet.data.sampling_iterator import SQLiteLogHook +from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.trainer import FlatRewards, RewardScalar From 279ecfcb8577dd3866d883ff4d1adaf2830f2f51 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:25:22 -0700 Subject: [PATCH 05/21] change batch size hyperparameters + fix nested dataclasses --- src/gflownet/algo/config.py | 38 +++++++++++++++++------------ src/gflownet/algo/graph_sampling.py | 3 ++- src/gflownet/config.py | 14 +++++------ src/gflownet/data/config.py | 16 +++++++----- src/gflownet/data/data_source.py | 15 ++++++++---- src/gflownet/models/config.py | 6 ++--- src/gflownet/online_trainer.py | 1 - src/gflownet/tasks/config.py | 8 +++--- src/gflownet/tasks/make_rings.py | 3 +-- src/gflownet/tasks/qm9.py | 3 ++- src/gflownet/tasks/qm9_moo.py | 7 ++---- src/gflownet/tasks/seh_frag.py | 8 +++--- src/gflownet/tasks/seh_frag_moo.py | 11 ++++----- src/gflownet/tasks/toy_seq.py | 4 +-- src/gflownet/trainer.py | 26 ++++++++++---------- src/gflownet/utils/config.py | 8 +++--- 16 files changed, 89 insertions(+), 82 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 0ccf2e0e..e2576982 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Optional @@ -95,8 +95,18 @@ class AlgoConfig: ---------- method : str The name of the algorithm to use (e.g. "TB") - global_batch_size : int - The batch size for training + num_from_policy : int + The number of on-policy samples for a training batch. + If using a replay buffer, see `replay.num_from_replay` for the number of samples from the replay buffer, and + `replay.num_new_samples` for the number of new samples to add to the replay buffer (e.g. `num_from_policy=0`, + and `num_new_samples=N` inserts `N` new samples in the replay buffer at each step, but does not make that data + part of the training batch). + num_from_dataset : int + The number of samples from the dataset for a training batch + valid_num_from_policy : int + The number of on-policy samples for a validation batch + valid_num_from_dataset : int + The number of samples from the dataset for a validation batch max_len : int The maximum length of a trajectory max_nodes : int @@ -105,11 +115,6 @@ class AlgoConfig: The maximum number of edges in a generated graph illegal_action_logreward : float The log reward an agent gets for illegal actions - offline_ratio: float - The ratio of samples drawn from `self.training_data` during training. The rest is drawn from - `self.sampling_model` - valid_offline_ratio: float - Idem but for validation, and `self.test_data`. train_random_action_prob : float The probability of taking a random action during training train_det_after: Optional[int] @@ -121,19 +126,20 @@ class AlgoConfig: """ method: str = "TB" - global_batch_size: int = 64 + num_from_policy: int = 64 + num_from_dataset: int = 0 + valid_num_from_policy: int = 64 + valid_num_from_dataset: int = 0 max_len: int = 128 max_nodes: int = 128 max_edges: int = 128 illegal_action_logreward: float = -100 - offline_ratio: float = 0.5 - valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 - tb: TBConfig = TBConfig() - moql: MOQLConfig = MOQLConfig() - a2c: A2CConfig = A2CConfig() - fm: FMConfig = FMConfig() - sql: SQLConfig = SQLConfig() + tb: TBConfig = field(default_factory=TBConfig) + moql: MOQLConfig = field(default_factory=MOQLConfig) + a2c: A2CConfig = field(default_factory=A2CConfig) + fm: FMConfig = field(default_factory=FMConfig) + sql: SQLConfig = field(default_factory=SQLConfig) diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 7ad4fc0a..0db5bcec 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -1,4 +1,5 @@ import copy +import warnings from typing import List, Optional import torch @@ -248,7 +249,7 @@ def not_done(lst): # TODO: This should be doable. if random_action_prob > 0: - raise NotImplementedError("Random action not implemented for backward sampling") + warnings.warn("Random action not implemented for backward sampling") while sum(done) < n: torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 73ed6f15..e8238e97 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass, field, fields, is_dataclass from typing import Optional from omegaconf import MISSING @@ -101,12 +101,12 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False - algo: AlgoConfig = AlgoConfig() - model: ModelConfig = ModelConfig() - opt: OptimizerConfig = OptimizerConfig() - replay: ReplayConfig = ReplayConfig() - task: TasksConfig = TasksConfig() - cond: ConditionalsConfig = ConditionalsConfig() + algo: AlgoConfig = field(default_factory=AlgoConfig) + model: ModelConfig = field(default_factory=ModelConfig) + opt: OptimizerConfig = field(default_factory=OptimizerConfig) + replay: ReplayConfig = field(default_factory=ReplayConfig) + task: TasksConfig = field(default_factory=TasksConfig) + cond: ConditionalsConfig = field(default_factory=ConditionalsConfig) def init_empty(cfg): diff --git a/src/gflownet/data/config.py b/src/gflownet/data/config.py index 5c5a9c84..ce1bac7e 100644 --- a/src/gflownet/data/config.py +++ b/src/gflownet/data/config.py @@ -16,15 +16,19 @@ class ReplayConfig: The number of samples to collect before starting to sample from the replay buffer hindsight_ratio : float The ratio of hindsight samples within a batch - batch_size : Optional[int] - The batch size for sampling from the replay buffer, defaults to the online batch size - replaces_online_data : bool - Whether to replace online data with samples from the replay buffer + num_from_replay : Optional[int] + The number of replayed samples for a training batch (defaults to cfg.algo.num_from_policy, i.e. a 50/50 split) + num_new_samples : Optional[int] + The number of new samples added to the replay at every training step. Defaults to cfg.algo.num_from_policy. If + smaller than num_from_policy then not all on-policy samples will be added to the replay. If larger + than num_from_policy then the training batch will not contain all the new samples, but the buffer will. + For example, if one wishes to sample N samples every step but only add them to the buffer and not make them + part of the training batch, then one should set replay.num_new_samples=N and algo.num_from_policy=0. """ use: bool = False capacity: Optional[int] = None warmup: Optional[int] = None hindsight_ratio: float = 0 - batch_size: Optional[int] = None - replaces_online_data: bool = True + num_from_replay: Optional[int] = None + num_new_samples: Optional[int] = None diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index b68bfc8c..ca8c6d1b 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -80,9 +80,14 @@ def __iter__(self): batch_info.update(d) yield self.create_batch(trajs, batch_info) - def do_sample_model(self, model, num_samples, keep_samples_in_batch=True): - if not keep_samples_in_batch: - assert self.replay_buffer is not None, "Throwing away samples without a replay buffer" + def do_sample_model(self, model, num_from_policy, num_new_replay_samples=None): + if num_new_replay_samples is not None: + assert self.replay_buffer is not None, "num_new_replay_samples specified without a replay buffer" + if num_new_replay_samples is None: + assert self.replay_buffer is None, "num_new_replay_samples not specified with a replay buffer" + + num_new_replay_samples = num_new_replay_samples or 0 + num_samples = max(num_from_policy, num_new_replay_samples) def iterator(): while self.active: @@ -94,9 +99,9 @@ def iterator(): self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) - self.send_to_replay(trajs) # This is a no-op if there is no replay buffer + self.send_to_replay(trajs[:num_new_replay_samples]) # This is a no-op if there is no replay buffer batch_info = self.call_sampling_hooks(trajs) - yield (trajs, batch_info) if keep_samples_in_batch else ([], {}) + yield (trajs[:num_from_policy], batch_info) self.iterators.append(iterator) return self diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index 05b00b6e..acce656a 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum @@ -35,5 +35,5 @@ class ModelConfig: num_layers: int = 3 num_emb: int = 128 dropout: float = 0 - graph_transformer: GraphTransformerConfig = GraphTransformerConfig() - seq_transformer: SeqTransformerConfig = SeqTransformerConfig() + graph_transformer: GraphTransformerConfig = field(default_factory=GraphTransformerConfig) + seq_transformer: SeqTransformerConfig = field(default_factory=SeqTransformerConfig) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 9d30d457..3ba1fde1 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -95,7 +95,6 @@ def setup(self): else: self.sampling_model = self.model - self.mb_size = self.cfg.algo.global_batch_size self.clip_grad_callback = { "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), "norm": lambda params: [torch.nn.utils.clip_grad_norm_(p, self.cfg.opt.clip_grad_param) for p in params], diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 4c29f634..3c8a0fab 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -62,7 +62,7 @@ class QM9MOOTaskConfig: @dataclass class TasksConfig: - qm9: QM9TaskConfig = QM9TaskConfig() - qm9_moo: QM9MOOTaskConfig = QM9MOOTaskConfig() - seh: SEHTaskConfig = SEHTaskConfig() - seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() + qm9: QM9TaskConfig = field(default_factory=QM9TaskConfig) + qm9_moo: QM9MOOTaskConfig = field(default_factory=QM9MOOTaskConfig) + seh: SEHTaskConfig = field(default_factory=SEHTaskConfig) + seh_moo: SEHMOOTaskConfig = field(default_factory=SEHMOOTaskConfig) diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 9211d038..00b2863b 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -41,8 +41,7 @@ class MakeRingsTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.num_workers = 8 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 128 cfg.model.num_layers = 4 diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 0e934906..4a3dd1ba 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -120,7 +120,8 @@ def set_default_hps(self, cfg: Config): cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 cfg.algo.max_nodes = 9 - cfg.algo.global_batch_size = 64 + cfg.algo.num_from_policy = 32 + cfg.algo.num_from_dataset = 32 cfg.algo.train_random_action_prob = 0.001 cfg.algo.illegal_action_logreward = -75 cfg.algo.sampling_tau = 0.0 diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 45c3576b..ec62643c 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -198,10 +198,8 @@ class QM9MOOTrainer(QM9GapTrainer): def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 - # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) - # sampling and set the offline ratio to 1 cfg.cond.valid_sample_cond_info = False - cfg.algo.valid_offline_ratio = 1 + cfg.algo.valid_num_from_dataset = 64 def setup_algo(self): algo = self.cfg.algo.method @@ -330,9 +328,8 @@ def on_validation_end(self, metrics: Dict[str, Any]): def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - n_from_dataset = self.cfg.algo.global_batch_size src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) if self.cfg.log_dir: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 54adcc7c..5c9040a7 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -111,7 +111,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class LittleSEHDataset(Dataset): """Note: this dataset isn't used by default, but turning it on showcases some features of this codebase. - To turn on, self `cfg.algo.offline_ratio > 0`""" + To turn on, self `cfg.algo.num_from_dataset > 0`""" def __init__(self, smis) -> None: super().__init__() @@ -146,8 +146,7 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20_000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 128 cfg.model.num_layers = 4 @@ -157,7 +156,7 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 - cfg.algo.valid_offline_ratio = 0 + cfg.algo.valid_num_from_policy = 64 cfg.num_validation_gen_steps = 10 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False @@ -213,7 +212,6 @@ def main(): config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 - config.algo.offline_ratio = 0.0 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 60f8092c..ca5b7e6a 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -222,10 +222,10 @@ class SEHMOOFragTrainer(SEHFragTrainer): def set_default_hps(self, cfg: Config): super().set_default_hps(cfg) cfg.algo.sampling_tau = 0.95 - # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) - # sampling and set the offline ratio to 1 - cfg.cond.valid_sample_cond_info = False - cfg.algo.valid_offline_ratio = 1 + # We sample from a dataset of valid conditional information, so we set this, and override + # build_validation_data_loader to use the dataset + cfg.cond.valid_sample_cond_info = False # TODO deprecate this? + cfg.algo.valid_num_from_dataset = 64 def setup_algo(self): algo = self.cfg.algo.method @@ -356,9 +356,8 @@ def on_validation_end(self, metrics: Dict[str, Any]): def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - n_from_dataset = self.cfg.algo.global_batch_size src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - src.do_conditionals_dataset_in_order(self.test_data, n_from_dataset, model) + src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) if self.cfg.log_dir: src.add_sampling_hook(SQLiteLogHook(str(pathlib.Path(self.cfg.log_dir) / "valid"), self.ctx)) diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 901baea6..2aece1b2 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -54,8 +54,7 @@ def set_default_hps(self, cfg: Config): cfg.opt.lr_decay = 20_000 cfg.opt.clip_grad_type = "norm" cfg.opt.clip_grad_param = 10 - cfg.algo.global_batch_size = 64 - cfg.algo.offline_ratio = 0 + cfg.algo.num_from_policy = 64 cfg.model.num_emb = 64 cfg.model.num_layers = 4 @@ -66,7 +65,6 @@ def set_default_hps(self, cfg: Config): cfg.algo.illegal_action_logreward = -75 cfg.algo.train_random_action_prob = 0.0 cfg.algo.valid_random_action_prob = 0.0 - cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.epsilon = None cfg.algo.tb.bootstrap_own_reward = False cfg.algo.tb.Z_learning_rate = 1e-2 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a9409562..5caed451 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -133,7 +133,6 @@ def __init__(self, config: Config, print_config=True): # the same as `model`. self.sampling_model: nn.Module self.replay_buffer: Optional[ReplayBuffer] - self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext self.task: GFNTask @@ -242,18 +241,21 @@ def build_training_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.offline_ratio)) - n_replayed = n_drawn if self.cfg.replay.batch_size is None else self.cfg.replay.batch_size - n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + if self.cfg.replay.use: + # None is fine for either value, it will be replaced by num_from_policy, but 0 is not + assert self.cfg.replay.num_from_replay != 0, "Replay is enabled but no samples are being drawn from it" + assert self.cfg.replay.num_new_samples != 0, "Replay is enabled but no new samples are being added to it" + + n_drawn = self.cfg.algo.num_from_policy + n_replayed = self.cfg.replay.num_from_replay or n_drawn if self.cfg.replay.use else 0 + n_new_replay_samples = self.cfg.replay.num_new_samples or n_drawn if self.cfg.replay.use else None + n_from_dataset = self.cfg.algo.num_from_dataset src = DataSource(self.cfg, self.ctx, self.algo, self.task, replay_buffer=replay_buffer) if n_from_dataset: src.do_sample_dataset(self.training_data, n_from_dataset, backwards_model=model) if n_drawn: - # If we are using a replay buffer, we can choose to keep the new samples in the minibatch, or just - # send them to the replay and train only on replay samples. - keep_samples_in_batch = not self.cfg.replay.use or not self.cfg.replay.replaces_online_data - src.do_sample_model(model, n_drawn, keep_samples_in_batch) + src.do_sample_model(model, n_drawn, n_new_replay_samples) if n_replayed and replay_buffer is not None: src.do_sample_replay(n_replayed) if self.cfg.log_dir: @@ -266,8 +268,8 @@ def build_validation_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) # TODO: we're changing the default, make sure anything that is using test data is adjusted src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - n_drawn = int(self.cfg.algo.global_batch_size * (1 - self.cfg.algo.valid_offline_ratio)) - n_from_dataset = self.cfg.algo.global_batch_size - n_drawn + n_drawn = self.cfg.algo.valid_num_from_policy + n_from_dataset = self.cfg.algo.valid_num_from_dataset src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) if n_from_dataset: @@ -285,10 +287,8 @@ def build_validation_data_loader(self) -> DataLoader: def build_final_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.model, send_to_device=True) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) - n_drawn = int(self.cfg.algo.global_batch_size) + n_drawn = self.cfg.algo.num_from_policy src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) assert self.cfg.num_final_gen_steps is not None # TODO: might be better to change total steps to total trajectories drawn diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index 8f67af3a..5ee5369a 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -73,7 +73,7 @@ class FocusRegionConfig: @dataclass class ConditionalsConfig: valid_sample_cond_info: bool = True - temperature: TempCondConfig = TempCondConfig() - moo: MultiObjectiveConfig = MultiObjectiveConfig() - weighted_prefs: WeightedPreferencesConfig = WeightedPreferencesConfig() - focus_region: FocusRegionConfig = FocusRegionConfig() + temperature: TempCondConfig = field(default_factory=TempCondConfig) + moo: MultiObjectiveConfig = field(default_factory=MultiObjectiveConfig) + weighted_prefs: WeightedPreferencesConfig = field(default_factory=WeightedPreferencesConfig) + focus_region: FocusRegionConfig = field(default_factory=FocusRegionConfig) From 282bbfb82f3cfd444ba1dc720da261aa292a4d84 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:47:06 -0700 Subject: [PATCH 06/21] move things around & prevent circular import --- src/gflownet/__init__.py | 88 +++++ src/gflownet/data/data_source.py | 14 +- src/gflownet/data/sampling_iterator.py | 489 ------------------------- src/gflownet/tasks/make_rings.py | 2 +- src/gflownet/tasks/qm9.py | 2 +- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/tasks/toy_seq.py | 2 +- src/gflownet/trainer.py | 85 +---- src/gflownet/utils/sqlite_log.py | 93 +++++ 11 files changed, 196 insertions(+), 585 deletions(-) delete mode 100644 src/gflownet/data/sampling_iterator.py create mode 100644 src/gflownet/utils/sqlite_log.py diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index e69de29b..9f445e01 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -0,0 +1,88 @@ +from typing import Dict, List, NewType, Optional, Tuple + +import torch_geometric.data as gd +from rdkit.Chem.rdchem import Mol as RDMol +from torch import Tensor, nn + +from .config import Config + +# This type represents an unprocessed list of reward signals/conditioning information +FlatRewards = NewType("FlatRewards", Tensor) # type: ignore + +# This type represents the outcome for a multi-objective task of +# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta +RewardScalar = NewType("RewardScalar", Tensor) # type: ignore + + +class GFNAlgorithm: + updates: int = 0 + global_cfg: Config + is_eval: bool = False + + def step(self): + self.updates += 1 # This isn't used anywhere? + + def compute_batch_losses( + self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Computes the loss for a batch of data, and proves logging informations + + Parameters + ---------- + model: nn.Module + The model being trained or evaluated + batch: gd.Batch + A batch of graphs + num_bootstrap: Optional[int] + The number of trajectories with reward targets in the batch (if applicable). + + Returns + ------- + loss: Tensor + The loss for that batch + info: Dict[str, Tensor] + Logged information about model predictions. + """ + raise NotImplementedError() + + def get_random_action_prob(self, it: int): + if self.is_eval: + return self.global_cfg.algo.valid_random_action_prob + if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: + return self.global_cfg.algo.train_random_action_prob + return 0 + + +class GFNTask: + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. + + Parameters + ---------- + cond_info: Dict[str, Tensor] + A dictionary with various conditional informations (e.g. temperature) + flat_reward: FlatRewards + A 2d tensor where each row represents a series of flat rewards. + + Returns + ------- + reward: RewardScalar + A 1d tensor, a scalar log-reward for each minibatch entry. + """ + raise NotImplementedError() + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + """Compute the flat rewards of mols according the the tasks' proxies + + Parameters + ---------- + mols: List[RDMol] + A list of RDKit molecules. + Returns + ------- + reward: FlatRewards + A 2d tensor, a vector of scalar reward for valid each molecule. + is_valid: Tensor + A 1d tensor, a boolean indicating whether the molecule is valid. + """ + raise NotImplementedError() diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index ca8c6d1b..1f74dc59 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -5,13 +5,12 @@ import torch from torch.utils.data import IterableDataset +from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng -# from gflownet.trainer import GFNAlgorithm, GFNTask - def cycle_call(it): while True: @@ -24,8 +23,8 @@ def __init__( self, cfg: Config, ctx: GraphBuildingEnvContext, - algo, #: GFNAlgorithm, - task, #: GFNTask, # TODO: this will cause a circular import + algo: GFNAlgorithm, + task: GFNTask, replay_buffer: Optional[ReplayBuffer] = None, is_algo_eval: bool = False, start_at_step: int = 0, @@ -230,7 +229,7 @@ def create_batch(self, trajs, batch_info): log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) - # TODO: find code that depends on batch.flat_rewards and deprecate it + batch.flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) return batch def compute_properties(self, trajs, mark_as_online=False): @@ -291,8 +290,9 @@ def relabel_in_hindsight(self, trajs): cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( cond_info, log_rewards, flat_rewards, hindsight_idxs ) - # TODO: This seems wrong, since we haven't recomputed is_valid - # log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + self.set_traj_cond_info(trajs, cond_info) + for i in range(len(trajs)): + trajs[i]["log_reward"] = log_rewards[i] def sample_idcs(self, n, num_samples): return self.rng.choice(n, num_samples, replace=False) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py deleted file mode 100644 index 4d2a8c07..00000000 --- a/src/gflownet/data/sampling_iterator.py +++ /dev/null @@ -1,489 +0,0 @@ -import os -import sqlite3 -from collections.abc import Iterable -from typing import Callable, List, Optional - -import numpy as np -import torch -import torch.nn as nn -from rdkit import RDLogger -from torch.utils.data import Dataset, IterableDataset - -from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphActionCategorical - - -class SamplingIterator(IterableDataset): - """This class allows us to parallelise and train faster. - - By separating sampling data/the model and building torch geometric - graphs from training the model, we can do the former in different - processes, which is much faster since much of graph construction - is CPU-bound. - - """ - - def __init__( - self, - dataset: Dataset, - model: nn.Module, - ctx, - algo, - task, - device, - batch_size: int = 1, - illegal_action_logreward: float = -50, - ratio: float = 0.5, - stream: bool = True, - replay_buffer: ReplayBuffer = None, - log_dir: str = None, - sample_cond_info: bool = True, - random_action_prob: float = 0.0, - det_after: Optional[int] = None, - hindsight_ratio: float = 0.0, - init_train_iter: int = 0, - ): - """Parameters - ---------- - dataset: Dataset - A dataset instance - model: nn.Module - The model we sample from (must be on CUDA already or share_memory() must be called so that - parameters are synchronized between each worker) - ctx: - The context for the environment, e.g. a MolBuildingEnvContext instance - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: GFNTask - A Task instance, e.g. a MakeRingsTask instance - device: torch.device - The device the model is on - replay_buffer: ReplayBuffer - The replay buffer for training on past data - batch_size: int - The number of trajectories, each trajectory will be comprised of many graphs, so this is - _not_ the batch size in terms of the number of graphs (that will depend on the task) - illegal_action_logreward: float - The logreward for invalid trajectories - ratio: float - The ratio of offline trajectories in the batch. - stream: bool - If True, data is sampled iid for every batch. Otherwise, this is a normal in-order - dataset iterator. - log_dir: str - If not None, logs each SamplingIterator worker's generated molecules to that file. - sample_cond_info: bool - If True (default), then the dataset is a dataset of points used in offline training. - If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - random_action_prob: float - The probability of taking a random action, passed to the graph sampler - init_train_iter: int - The initial training iteration, incremented and passed to task.sample_conditional_information - """ - self.data = dataset - self.model = model - self.replay_buffer = replay_buffer - self.batch_size = batch_size - self.illegal_action_logreward = illegal_action_logreward - self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) - self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) - self.ratio = ratio - self.ctx = ctx - self.algo = algo - self.task = task - self.device = device - self.stream = stream - self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely - self.sample_cond_info = sample_cond_info - self.random_action_prob = random_action_prob - self.hindsight_ratio = hindsight_ratio - self.train_it = init_train_iter - self.do_validate_batch = False # Turn this on for debugging - self.iter = 0 - self.det_after = det_after - # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) - # then "offline" now refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally - if not sample_cond_info: - self.offline_batch_size = self.online_batch_size = self.batch_size - - # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we - # don't want to initialize per-worker things just yet, such as where the log the worker writes - # to. This must be done in __iter__, which is called by the DataLoader once this instance - # has been copied into a new python process. - self.log_dir = log_dir - self.log = SQLiteLog() - self.log_hooks: List[Callable] = [] - - def add_log_hook(self, hook: Callable): - self.log_hooks.append(hook) - - def _idx_iterator(self): - RDLogger.DisableLog("rdApp.*") - if self.stream: - # If we're streaming data, just sample `offline_batch_size` indices - while True: - if self.offline_batch_size == 0 or len(self.data) == 0: - yield np.arange(0, 0) - else: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) - else: - # Otherwise, figure out which indices correspond to this worker - worker_info = torch.utils.data.get_worker_info() - n = len(self.data) - if n == 0: - yield np.arange(0, 0) - return - assert ( - self.offline_batch_size > 0 - ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" - if worker_info is None: # no multi-processing - start, end, wid = 0, n, -1 - else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id - start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - bs = self.offline_batch_size - if end - start <= bs: - yield np.arange(start, end) - return - for i in range(start, end - bs, bs): - yield np.arange(i, i + bs) - if i + bs < end: - yield np.arange(i + bs, end) - - def __len__(self): - if self.stream: - return int(1e6) - if len(self.data) == 0 and self.sample_online_once: - return 1 - return len(self.data) - - def __iter__(self): - self.iter += 1 - if self.det_after is not None and self.iter > self.det_after: - self.random_action_prob = 0 - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - # Now that we know we are in a worker instance, we can initialize per-worker things - self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) - self.ctx.device = self.device - if self.log_dir is not None: - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log.connect(self.log_path) - - for idcs in self._idx_iterator(): - num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] - # Sample conditional info such as temperature, trade-off weights, etc. - - if self.sample_cond_info: - num_online = self.online_batch_size - cond_info = self.task.sample_conditional_information( - num_offline + self.online_batch_size, self.train_it - ) - - # Sample some dataset data - graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) - flat_rewards = ( - list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] - ) - - trajs = self.algo.create_training_data_from_graphs( - graphs, self.model, cond_info["encoding"][:num_offline], 0 - ) - - else: # If we're not sampling the conditionals, then the idcs refer to listed preferences - num_online = num_offline - num_offline = 0 - cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) # This is sus, what's going on here? - ) - trajs, flat_rewards = [], [] - - # Sample some on-policy data - is_valid = torch.ones(num_offline + num_online).bool() - if num_online > 0: - with torch.no_grad(): - trajs += self.algo.create_training_data_from_own_samples( - self.model, - num_online, - cond_info["encoding"][num_offline:], - random_action_prob=self.random_action_prob, - ) - if self.algo.bootstrap_own_reward: - # The model can be trained to predict its own reward, - # i.e. predict the output of cond_info_to_logreward - pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] - flat_rewards += pred_reward - else: - # Otherwise, query the task for flat rewards - valid_idcs = torch.tensor( - [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] - ).long() - # fetch the valid trajectories endpoints - mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] - # ask the task to compute their reward - online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) - assert ( - online_flat_rew.ndim == 2 - ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" - # The task may decide some of the mols are invalid, we have to again filter those - valid_idcs = valid_idcs[m_is_valid] - pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) - pred_reward[valid_idcs - num_offline] = online_flat_rew - is_valid[num_offline:] = False - is_valid[valid_idcs] = True - flat_rewards += list(pred_reward) - # Override the is_valid key in case the task made some mols invalid - for i in range(num_online): - trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() - - # Compute scalar rewards from conditional information & flat rewards - flat_rewards = torch.stack(flat_rewards) - log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - assert len(trajs) == num_online + num_offline - # Computes some metrics - extra_info = {"random_action_prob": self.random_action_prob} - if num_online > 0: - H = sum(i["fwd_logprob"] for i in trajs[num_offline:]) - extra_info["entropy"] = -H / num_online - extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]]) - if not self.sample_cond_info: - # If we're using a dataset of preferences, the user may want to know the id of the preference - for i, j in zip(trajs, idcs): - i["data_idx"] = j - # note: we convert back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) - rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: - self.log_generated( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, - ) - if num_online > 0: - extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() - for hook in self.log_hooks: - extra_info.update( - hook( - trajs[num_offline:], - rewards[num_offline:], - flat_rewards[num_offline:], - {k: v[num_offline:] for k, v in cond_info.items()}, - ) - ) - - if self.replay_buffer is not None: - # If we have a replay buffer, we push the online trajectories in it - # and resample immediately such that the "online" data in the batch - # comes from a more stable distribution (try to avoid forgetting) - - # cond_info is a dict, so we need to convert it to a list of dicts - cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] - - # push the online trajectories in the replay buffer and sample a new 'online' batch - for i in range(num_offline, len(trajs)): - self.replay_buffer.push( - trajs[i], - log_rewards[i], - flat_rewards[i], - cond_info[i], - is_valid[i], - ) - replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( - num_online - ) - - # append the online trajectories to the offline ones - trajs = trajs[:num_offline] + replay_trajs - log_rewards = torch.cat([log_rewards[:num_offline], replay_logr], dim=0) - flat_rewards = torch.cat([flat_rewards[:num_offline], replay_fr], dim=0) - cond_info = cond_info[:num_offline] + replay_condinfo # list of dicts - is_valid = torch.cat([is_valid[:num_offline], replay_valid], dim=0) - - # convert cond_info back to a dict - cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} - - if self.hindsight_ratio > 0.0: - # Relabels some of the online trajectories with hindsight - assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" - # samples indexes of trajectories without repeats - hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline - cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( - cond_info, log_rewards, flat_rewards, hindsight_idxs - ) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.preferences = cond_info.get("preferences", None) - batch.focus_dir = cond_info.get("focus_dir", None) - batch.extra_info = extra_info - if self.ctx.has_n(): - log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] - batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) - batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object - - # Only activate for debugging your environment or dataset (e.g. the dataset could be - # generating trajectories with illegal actions) - if self.do_validate_batch: - self.validate_batch(batch, trajs) - - self.train_it += worker_info.num_workers if worker_info is not None else 1 - yield batch - - def validate_batch(self, batch, trajs): - for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( - [(batch.bck_actions, self.ctx.bck_action_type_order)] - if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") - else [] - ): - mask_cat = GraphActionCategorical( - batch, - [self.model._action_type_to_mask(t, batch) for t in atypes], - [self.model._action_type_to_key[t] for t in atypes], - [None for _ in atypes], - ) - masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) - num_trajs = len(trajs) - batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) - first_graph_idx = torch.zeros_like(batch.traj_lens) - torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) - if masked_action_is_used.sum() != 0: - invalid_idx = masked_action_is_used.argmax().item() - traj_idx = batch_idx[invalid_idx].item() - timestep = invalid_idx - first_graph_idx[traj_idx].item() - raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) - - def log_generated(self, trajs, rewards, flat_rewards, cond_info): - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - - data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, data_labels) - - -class SQLiteLogHook: - def __init__(self, log_dir, ctx) -> None: - self.log = None # Only initialized in __call__, which will occur inside the worker - self.log_dir = log_dir - self.ctx = ctx - self.data_labels = None - - def __call__(self, trajs, rewards, flat_rewards, cond_info): - if self.log is None: - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log = SQLiteLog() - self.log.connect(self.log_path) - - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - if self.data_labels is None: - self.data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, self.data_labels) - return {} - - -class SQLiteLog: - def __init__(self, timeout=300): - """Creates a log instance, but does not connect it to any db.""" - self.is_connected = False - self.db = None - self.timeout = timeout - - def connect(self, db_path: str): - """Connects to db_path - - Parameters - ---------- - db_path: str - The sqlite3 database path. If it does not exist, it will be created. - """ - self.db = sqlite3.connect(db_path, timeout=self.timeout) - cur = self.db.cursor() - self._has_results_table = len( - cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() - ) - cur.close() - - def _make_results_table(self, types, names): - type_map = {str: "text", float: "real", int: "real"} - col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) - cur = self.db.cursor() - cur.execute(f"create table results ({col_str})") - self._has_results_table = True - cur.close() - - def insert_many(self, rows, column_names): - assert all( - [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] - ), "rows must only contain scalars" - if not self._has_results_table: - self._make_results_table([type(i) for i in rows[0]], column_names) - cur = self.db.cursor() - cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec - cur.close() - self.db.commit() diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 00b2863b..34f47924 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -7,10 +7,10 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar class MakeRingsTask(GFNTask): diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 4a3dd1ba..0bad429f 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -9,11 +9,11 @@ from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index ec62643c..4ca55483 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader, Dataset import gflownet.models.mxmnet as mxmnet +from gflownet import FlatRewards, RewardScalar from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config @@ -19,7 +20,6 @@ from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks -from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 5c9040a7..c24a8d05 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -11,11 +11,11 @@ from torch.utils.data import Dataset from torch_geometric.data import Data +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index ca3fa0bc..8ec06d5e 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -10,6 +10,7 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet import FlatRewards, RewardScalar from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty @@ -18,7 +19,6 @@ from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask -from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index 2aece1b2..f2c75f60 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -5,11 +5,11 @@ import torch from torch import Tensor +from gflownet import FlatRewards, GFNTask, RewardScalar from gflownet.config import Config, init_empty from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv from gflownet.models.seq_transformer import SeqTransformerGFN from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar from gflownet.utils.conditioning import TemperatureConditional from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 5caed451..92d45538 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -3,7 +3,7 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple +from typing import Any, Callable, Dict, List, Optional, Protocol import numpy as np import torch @@ -13,10 +13,10 @@ import wandb from omegaconf import OmegaConf from rdkit import RDLogger -from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import DataLoader, Dataset +from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer from gflownet.data.sampling_iterator import SQLiteLogHook @@ -27,87 +27,6 @@ from .config import Config -# This type represents an unprocessed list of reward signals/conditioning information -FlatRewards = NewType("FlatRewards", Tensor) # type: ignore - -# This type represents the outcome for a multi-objective task of -# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta -RewardScalar = NewType("RewardScalar", Tensor) # type: ignore - - -class GFNAlgorithm: - updates: int = 0 - global_cfg: Config - is_eval: bool = False - - def step(self): - self.updates += 1 # This isn't used anywhere? - - def compute_batch_losses( - self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 - ) -> Tuple[Tensor, Dict[str, Tensor]]: - """Computes the loss for a batch of data, and proves logging informations - - Parameters - ---------- - model: nn.Module - The model being trained or evaluated - batch: gd.Batch - A batch of graphs - num_bootstrap: Optional[int] - The number of trajectories with reward targets in the batch (if applicable). - - Returns - ------- - loss: Tensor - The loss for that batch - info: Dict[str, Tensor] - Logged information about model predictions. - """ - raise NotImplementedError() - - def get_random_action_prob(self, it: int): - if self.is_eval: - return self.global_cfg.algo.valid_random_action_prob - if self.global_cfg.algo.train_det_after is None or it < self.global_cfg.algo.train_det_after: - return self.global_cfg.algo.train_random_action_prob - return 0 - - -class GFNTask: - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. - - Parameters - ---------- - cond_info: Dict[str, Tensor] - A dictionary with various conditional informations (e.g. temperature) - flat_reward: FlatRewards - A 2d tensor where each row represents a series of flat rewards. - - Returns - ------- - reward: RewardScalar - A 1d tensor, a scalar log-reward for each minibatch entry. - """ - raise NotImplementedError() - - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - """Compute the flat rewards of mols according the the tasks' proxies - - Parameters - ---------- - mols: List[RDMol] - A list of RDKit molecules. - Returns - ------- - reward: FlatRewards - A 2d tensor, a vector of scalar reward for valid each molecule. - is_valid: Tensor - A 1d tensor, a boolean indicating whether the molecule is valid. - """ - raise NotImplementedError() - class Closable(Protocol): def close(self): diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py new file mode 100644 index 00000000..0740baf8 --- /dev/null +++ b/src/gflownet/utils/sqlite_log.py @@ -0,0 +1,93 @@ +from typing import Iterable +import os +import sqlite3 +import torch + +class SQLiteLogHook: + def __init__(self, log_dir, ctx) -> None: + self.log = None # Only initialized in __call__, which will occur inside the worker + self.log_dir = log_dir + self.ctx = ctx + self.data_labels = None + + def __call__(self, trajs, rewards, flat_rewards, cond_info): + if self.log is None: + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log = SQLiteLog() + self.log.connect(self.log_path) + + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + if self.data_labels is None: + self.data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, self.data_labels) + return {} + + +class SQLiteLog: + def __init__(self, timeout=300): + """Creates a log instance, but does not connect it to any db.""" + self.is_connected = False + self.db = None + self.timeout = timeout + + def connect(self, db_path: str): + """Connects to db_path + + Parameters + ---------- + db_path: str + The sqlite3 database path. If it does not exist, it will be created. + """ + self.db = sqlite3.connect(db_path, timeout=self.timeout) + cur = self.db.cursor() + self._has_results_table = len( + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() + ) + cur.close() + + def _make_results_table(self, types, names): + type_map = {str: "text", float: "real", int: "real"} + col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) + cur = self.db.cursor() + cur.execute(f"create table results ({col_str})") + self._has_results_table = True + cur.close() + + def insert_many(self, rows, column_names): + assert all( + [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] + ), "rows must only contain scalars" + if not self._has_results_table: + self._make_results_table([type(i) for i in rows[0]], column_names) + cur = self.db.cursor() + cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec + cur.close() + self.db.commit() From c3bc6d05087195aab4df6938a8e3105cb6ce2d66 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:47:56 -0700 Subject: [PATCH 07/21] tox --- src/gflownet/utils/sqlite_log.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 0740baf8..1ac183db 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -1,8 +1,10 @@ -from typing import Iterable import os import sqlite3 +from typing import Iterable + import torch + class SQLiteLogHook: def __init__(self, log_dir, ctx) -> None: self.log = None # Only initialized in __call__, which will occur inside the worker From b1c5630019d569243a6987b33891cda322842772 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 08:58:32 -0700 Subject: [PATCH 08/21] fix imports --- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/trainer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 4ca55483..953c48e4 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -16,13 +16,13 @@ from gflownet.config import Config from gflownet.data.data_source import DataSource from gflownet.data.qm9 import QM9Dataset -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9 import QM9GapTask, QM9GapTrainer from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks from gflownet.utils import metrics from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.sqlite_log import SQLiteLogHook from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 8ec06d5e..bbab91b7 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -15,13 +15,13 @@ from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty from gflownet.data.data_source import DataSource -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.models import bengio2021flow from gflownet.tasks.seh_frag import SEHFragTrainer, SEHTask from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.sqlite_log import SQLiteLogHook from gflownet.utils.transforms import to_logreward diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 92d45538..f7406fd7 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -19,11 +19,11 @@ from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.data.sampling_iterator import SQLiteLogHook from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device from gflownet.utils.multiprocessing_proxy import mp_object_wrapper +from gflownet.utils.sqlite_log import SQLiteLogHook from .config import Config From a64a639bf6ff66c57ded4aa271aff9cf372d272c Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 09:11:41 -0700 Subject: [PATCH 09/21] replace device references with get_worker_device --- src/gflownet/tasks/qm9.py | 10 ++++++---- src/gflownet/tasks/qm9_moo.py | 2 +- src/gflownet/tasks/seh_frag.py | 5 ++++- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/trainer.py | 19 +++++++------------ src/gflownet/utils/conditioning.py | 4 ++-- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 0bad429f..266ff77b 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -15,6 +15,7 @@ from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import to_logreward @@ -30,7 +31,8 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models(cfg.task.qm9.model_path, torch.device(cfg.device)) + self.device = get_worker_device() + self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() @@ -60,7 +62,7 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self, path, device): + def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? try: @@ -73,8 +75,8 @@ def load_task_models(self, path, device): "https://storage.googleapis.com/emmanuel-data/models/mxmnet_gap_model.pt", ) gap_model.load_state_dict(state_dict) - gap_model.to(device) - gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) + gap_model.to(self.device) + gap_model = self._wrap_model(gap_model) return {"mxmnet_gap": gap_model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 953c48e4..51029e3a 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -326,7 +326,7 @@ def on_validation_end(self, metrics: Dict[str, Any]): return {"topk": TopKMetricCB()} def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index c24a8d05..a731a57c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -17,6 +17,7 @@ from gflownet.models import bengio2021flow from gflownet.online_trainer import StandardOnlineTrainer from gflownet.utils.conditioning import TemperatureConditional +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import to_logreward @@ -43,6 +44,7 @@ def __init__( self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() + self.device = get_worker_device() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -52,7 +54,8 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model, self.device = self._wrap_model(model, send_to_device=True) + model.to(self.device) + model = self._wrap_model(model, send_to_device=True) return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index bbab91b7..ef8def85 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -351,7 +351,7 @@ def on_validation_end(self, metrics: Dict[str, Any]): return callback_dict def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) src.do_conditionals_dataset_in_order(self.test_data, self.cfg.algo.valid_num_from_dataset, model) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index f7406fd7..c8be7312 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -122,11 +122,9 @@ def setup(self): self.setup_algo() self.setup_model() - def _wrap_for_mp(self, obj, send_to_device=False): + def _wrap_for_mp(self, obj): """Wraps an object in a placeholder whose reference can be sent to a data worker process (only if the number of workers is non-zero).""" - if send_to_device: - obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: wrapper = mp_object_wrapper( obj, @@ -135,9 +133,9 @@ def _wrap_for_mp(self, obj, send_to_device=False): pickle_messages=self.cfg.pickle_mp_messages, ) self.to_terminate.append(wrapper.terminate) - return wrapper.placeholder, torch.device("cpu") + return wrapper.placeholder else: - return obj, self.device + return obj def build_callbacks(self): return {} @@ -153,12 +151,9 @@ def _make_data_loader(self, src): def build_training_data_loader(self) -> DataLoader: # Since the model may be used by a worker in a different process, we need to wrap it. - # The device `dev` returned here is the device that the worker will use to interact with the model; - # normally, if the main process has the model on 'cuda', this will simply be 'cpu' (since workers - # don't have CUDA access). # See implementation_nodes.md for more details. - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) + model = self._wrap_for_mp(self.sampling_model) + replay_buffer = self._wrap_for_mp(self.replay_buffer) if self.cfg.replay.use: # None is fine for either value, it will be replaced by num_from_policy, but 0 is not @@ -184,7 +179,7 @@ def build_training_data_loader(self) -> DataLoader: return self._make_data_loader(src) def build_validation_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) # TODO: we're changing the default, make sure anything that is using test data is adjusted src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) n_drawn = self.cfg.algo.valid_num_from_policy @@ -205,7 +200,7 @@ def build_validation_data_loader(self) -> DataLoader: return self._make_data_loader(src) def build_final_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.model, send_to_device=True) + model = self._wrap_for_mp(self.model) n_drawn = self.cfg.algo.num_from_policy src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py index aece3868..0630be55 100644 --- a/src/gflownet/utils/conditioning.py +++ b/src/gflownet/utils/conditioning.py @@ -12,6 +12,7 @@ from gflownet.config import Config from gflownet.utils import metrics from gflownet.utils.focus_model import TabularFocusModel +from gflownet.utils.misc import get_worker_device from gflownet.utils.transforms import thermometer @@ -142,8 +143,7 @@ def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): if focus_type is not None and "learned" in focus_type: if focus_type == "learned-tabular": self.focus_model = TabularFocusModel( - # TODO: proper device propagation - device=torch.device("cpu"), + device=get_worker_device(), n_objectives=cfg.cond.moo.num_objectives, state_space_res=self.cfg.focus_model_state_space_res, ) From 28bcc5946779b89f142b5407aa8230a95aeb52ec Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 09:39:45 -0700 Subject: [PATCH 10/21] little fixes --- src/gflownet/data/data_source.py | 2 +- src/gflownet/tasks/seh_frag.py | 4 ++-- src/gflownet/utils/multiobjective_hooks.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 1f74dc59..1d4f3384 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -170,7 +170,7 @@ def iterator(): # I'm also not a fan of encode_conditional_information, it assumes lots of things about what's passed to # it and the state of the program (e.g. validation mode) cond_info = self.task.encode_conditional_information(torch.stack([data[i] for i in idcs])) - trajs = self.algo.create_training_data_from_own_samples(model, num_samples, cond_info["encoding"], p) + trajs = self.algo.create_training_data_from_own_samples(model, len(idcs), cond_info["encoding"], p) self.set_traj_cond_info(trajs, cond_info) # Attach the cond info to the trajs self.compute_properties(trajs, mark_as_online=True) self.compute_log_rewards(trajs) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index a731a57c..f02f3a44 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -40,11 +40,11 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng + self.device = get_worker_device() self.models = self._load_task_models() self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) self.num_cond_dim = self.temperature_conditional.encoding_size() - self.device = get_worker_device() def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: return FlatRewards(torch.as_tensor(y) / 8) @@ -55,7 +55,7 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() model.to(self.device) - model = self._wrap_model(model, send_to_device=True) + model = self._wrap_model(model) return {"seh": model} def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 115bef3a..359d1f6b 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -74,7 +74,7 @@ def _hsri(self, x): def _run_pareto_accumulation(self): num_updates = 0 timeouts = 0 - while not self.stop.is_set() or timeouts < 200: + while not self.stop.is_set() and timeouts < 200: try: r, smi, owid = self.pareto_queue.get(block=True, timeout=1) except queue.Empty: From 4811e7c16eb10cd0557030ba1acdf2bda89beb65 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 12:01:45 -0700 Subject: [PATCH 11/21] a few more stragglers --- src/gflownet/data/data_source.py | 2 +- src/gflownet/tasks/qm9.py | 4 +++- src/gflownet/tasks/seh_frag.py | 5 ++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 1d4f3384..d78a2a7f 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -266,7 +266,7 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_rewards"], t["flat_rewards"], t["cond_info"], t["is_valid"]) + self.replay_buffer.push(t, t["log_reward"], t["flat_rewards"], t["cond_info"], t["is_valid"]) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 266ff77b..5f489938 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -87,7 +87,9 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) + batch.to( + self.models["mxmnet_gap"].device if hasattr(self.models["mxmnet_gap"], "device") else get_worker_device() + ) preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 preds = ( diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index f02f3a44..e64d642d 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -40,7 +40,6 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.device = get_worker_device() self.models = self._load_task_models() self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) @@ -54,7 +53,7 @@ def inverse_flat_reward_transform(self, rp): def _load_task_models(self): model = bengio2021flow.load_original_model() - model.to(self.device) + model.to(get_worker_device()) model = self._wrap_model(model) return {"seh": model} @@ -66,7 +65,7 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - batch.to(self.device) + batch.to(self.models["seh"].device if hasattr(self.models["seh"], "device") else get_worker_device()) preds = self.models["seh"](batch).reshape((-1,)).data.cpu() preds[preds.isnan()] = 0 return self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) From 7d32ac142c64bb7fde180b5adbe7ad7e5ddad6c2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 23 Feb 2024 15:44:26 -0700 Subject: [PATCH 12/21] proof of concept of using shared pinned buffers --- src/gflownet/algo/graph_sampling.py | 4 +- src/gflownet/config.py | 1 + src/gflownet/data/sampling_iterator.py | 453 ++++++++++++++++++++ src/gflownet/models/graph_transformer.py | 2 + src/gflownet/tasks/seh_frag.py | 1 + src/gflownet/trainer.py | 54 ++- src/gflownet/utils/multiprocessing_proxy.py | 198 ++++++++- 7 files changed, 707 insertions(+), 6 deletions(-) create mode 100644 src/gflownet/data/sampling_iterator.py diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 0db5bcec..3cd253df 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -114,7 +114,9 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + batch = self.ctx.collate(torch_graphs) + batch.cond_info = cond_info[not_done_mask] + fwd_cat, *_, log_reward_preds = model(batch.to(dev), None) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks # Device which graphs in the minibatch will get their action randomized diff --git a/src/gflownet/config.py b/src/gflownet/config.py index e8238e97..070d68d1 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -101,6 +101,7 @@ class Config: pickle_mp_messages: bool = False git_hash: Optional[str] = None overwrite_existing_exp: bool = False + mp_buffer_size: Optional[int] = None algo: AlgoConfig = field(default_factory=AlgoConfig) model: ModelConfig = field(default_factory=ModelConfig) opt: OptimizerConfig = field(default_factory=OptimizerConfig) diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py new file mode 100644 index 00000000..8daec21b --- /dev/null +++ b/src/gflownet/data/sampling_iterator.py @@ -0,0 +1,453 @@ +import os +import sqlite3 +from collections.abc import Iterable +from copy import deepcopy +from typing import Callable, List + +import numpy as np +import torch +import torch.nn as nn +import torch.multiprocessing as mp +from rdkit import RDLogger +from torch.utils.data import Dataset, IterableDataset + +from gflownet.data.replay_buffer import ReplayBuffer +from gflownet.envs.graph_building_env import GraphActionCategorical +from gflownet.utils.multiprocessing_proxy import put_into_batch_buffer, SharedPinnedBuffer + + +class SamplingIterator(IterableDataset): + """This class allows us to parallelise and train faster. + + By separating sampling data/the model and building torch geometric + graphs from training the model, we can do the former in different + processes, which is much faster since much of graph construction + is CPU-bound. + + """ + + def __init__( + self, + dataset: Dataset, + model: nn.Module, + ctx, + algo, + task, + device, + batch_size: int = 1, + illegal_action_logreward: float = -50, + ratio: float = 0.5, + stream: bool = True, + replay_buffer: ReplayBuffer = None, + log_dir: str = None, + sample_cond_info: bool = True, + random_action_prob: float = 0.0, + hindsight_ratio: float = 0.0, + init_train_iter: int = 0, + buffer_size: int = None, + num_workers: int = 1, + do_multiple_buffers = True, # If True, each worker has its own buffer; doesn't seem to have much impact either way + ): + """Parameters + ---------- + dataset: Dataset + A dataset instance + model: nn.Module + The model we sample from (must be on CUDA already or share_memory() must be called so that + parameters are synchronized between each worker) + ctx: + The context for the environment, e.g. a MolBuildingEnvContext instance + algo: + The training algorithm, e.g. a TrajectoryBalance instance + task: GFNTask + A Task instance, e.g. a MakeRingsTask instance + device: torch.device + The device the model is on + replay_buffer: ReplayBuffer + The replay buffer for training on past data + batch_size: int + The number of trajectories, each trajectory will be comprised of many graphs, so this is + _not_ the batch size in terms of the number of graphs (that will depend on the task) + illegal_action_logreward: float + The logreward for invalid trajectories + ratio: float + The ratio of offline trajectories in the batch. + stream: bool + If True, data is sampled iid for every batch. Otherwise, this is a normal in-order + dataset iterator. + log_dir: str + If not None, logs each SamplingIterator worker's generated molecules to that file. + sample_cond_info: bool + If True (default), then the dataset is a dataset of points used in offline training. + If False, then the dataset is a dataset of preferences (e.g. used to validate the model) + random_action_prob: float + The probability of taking a random action, passed to the graph sampler + init_train_iter: int + The initial training iteration, incremented and passed to task.sample_conditional_information + """ + self.data = dataset + self.model = model + self.replay_buffer = replay_buffer + self.batch_size = batch_size + self.illegal_action_logreward = illegal_action_logreward + self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) + self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) + self.ratio = ratio + self.ctx = ctx + self.algo = algo + self.task = task + self.device = device + self.stream = stream + self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely + self.sample_cond_info = sample_cond_info + self.random_action_prob = random_action_prob + self.hindsight_ratio = hindsight_ratio + self.train_it = init_train_iter + self.do_validate_batch = False # Turn this on for debugging + + self.result_buffer_size = buffer_size + self.do_multiple_buffers = do_multiple_buffers + if buffer_size and do_multiple_buffers: + self.result_buffer = [SharedPinnedBuffer(buffer_size) for _ in range(num_workers)] + elif buffer_size: + self.result_buffer = SharedPinnedBuffer(buffer_size) + self.round_robin_cond = mp.Condition() + self.round_robin_counter = torch.zeros(1) + self.round_robin_counter.share_memory_() + + # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) + # then "offline" now refers to cond info and online to x, so no duplication and we don't end + # up with 2*batch_size accidentally + if not sample_cond_info: + self.offline_batch_size = self.online_batch_size = self.batch_size + + # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we + # don't want to initialize per-worker things just yet, such as where the log the worker writes + # to. This must be done in __iter__, which is called by the DataLoader once this instance + # has been copied into a new python process. + self.log_dir = log_dir + self.log = SQLiteLog() + self.log_hooks: List[Callable] = [] + + def add_log_hook(self, hook: Callable): + self.log_hooks.append(hook) + + def _idx_iterator(self): + RDLogger.DisableLog("rdApp.*") + if self.stream: + # If we're streaming data, just sample `offline_batch_size` indices + while True: + yield self.rng.integers(0, len(self.data), self.offline_batch_size) + else: + # Otherwise, figure out which indices correspond to this worker + worker_info = torch.utils.data.get_worker_info() + n = len(self.data) + if n == 0: + yield np.arange(0, 0) + return + assert ( + self.offline_batch_size > 0 + ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" + if worker_info is None: # no multi-processing + start, end, wid = 0, n, -1 + else: # split the data into chunks (per-worker) + nw = worker_info.num_workers + wid = worker_info.id + start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) + bs = self.offline_batch_size + if end - start <= bs: + yield np.arange(start, end) + return + for i in range(start, end - bs, bs): + yield np.arange(i, i + bs) + if i + bs < end: + yield np.arange(i + bs, end) + + def __len__(self): + if self.stream: + return int(1e6) + if len(self.data) == 0 and self.sample_online_once: + return 1 + return len(self.data) + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + self._wid = worker_info.id if worker_info is not None else 0 + # Now that we know we are in a worker instance, we can initialize per-worker things + self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) + self.ctx.device = self.device + if self.log_dir is not None: + os.makedirs(self.log_dir, exist_ok=True) + self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log.connect(self.log_path) + + for idcs in self._idx_iterator(): + num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] + # Sample conditional info such as temperature, trade-off weights, etc. + + if self.sample_cond_info: + num_online = self.online_batch_size + cond_info = self.task.sample_conditional_information( + num_offline + self.online_batch_size, self.train_it + ) + + # Sample some dataset data + graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) + flat_rewards = ( + list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] + ) + trajs = self.algo.create_training_data_from_graphs( + graphs, self.model, cond_info["encoding"][:num_offline], 0 + ) + + else: # If we're not sampling the conditionals, then the idcs refer to listed preferences + num_online = num_offline + num_offline = 0 + cond_info = self.task.encode_conditional_information( + steer_info=torch.stack([self.data[i] for i in idcs]) + ) + trajs, flat_rewards = [], [] + + # Sample some on-policy data + is_valid = torch.ones(num_offline + num_online).bool() + if num_online > 0: + with torch.no_grad(): + trajs += self.algo.create_training_data_from_own_samples( + self.model, + num_online, + cond_info["encoding"][num_offline:], + random_action_prob=self.random_action_prob, + ) + if self.algo.bootstrap_own_reward: + # The model can be trained to predict its own reward, + # i.e. predict the output of cond_info_to_logreward + pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] + flat_rewards += pred_reward + else: + # Otherwise, query the task for flat rewards + valid_idcs = torch.tensor( + [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] + ).long() + # fetch the valid trajectories endpoints + mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] + # ask the task to compute their reward + online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) + assert ( + online_flat_rew.ndim == 2 + ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + # The task may decide some of the mols are invalid, we have to again filter those + valid_idcs = valid_idcs[m_is_valid] + pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) + pred_reward[valid_idcs - num_offline] = online_flat_rew + is_valid[num_offline:] = False + is_valid[valid_idcs] = True + flat_rewards += list(pred_reward) + # Override the is_valid key in case the task made some mols invalid + for i in range(num_online): + trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() + + # Compute scalar rewards from conditional information & flat rewards + flat_rewards = torch.stack(flat_rewards) + log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Computes some metrics + extra_info = {} + if not self.sample_cond_info: + # If we're using a dataset of preferences, the user may want to know the id of the preference + for i, j in zip(trajs, idcs): + i["data_idx"] = j + # note: we convert back into natural rewards for logging purposes + # (allows to take averages and plot in objective space) + # TODO: implement that per-task (in case they don't apply the same beta and log transformations) + rewards = torch.exp(log_rewards / cond_info["beta"]) + if num_online > 0 and self.log_dir is not None: + self.log_generated( + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, + ) + if num_online > 0: + extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() + for hook in self.log_hooks: + extra_info.update( + hook( + deepcopy(trajs[num_offline:]), + deepcopy(rewards[num_offline:]), + deepcopy(flat_rewards[num_offline:]), + {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, + ) + ) + + if self.replay_buffer is not None: + # If we have a replay buffer, we push the online trajectories in it + # and resample immediately such that the "online" data in the batch + # comes from a more stable distribution (try to avoid forgetting) + + # cond_info is a dict, so we need to convert it to a list of dicts + cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] + + # push the online trajectories in the replay buffer and sample a new 'online' batch + for i in range(num_offline, len(trajs)): + self.replay_buffer.push( + deepcopy(trajs[i]), + deepcopy(log_rewards[i]), + deepcopy(flat_rewards[i]), + deepcopy(cond_info[i]), + deepcopy(is_valid[i]), + ) + replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( + num_online + ) + + # append the online trajectories to the offline ones + trajs[num_offline:] = replay_trajs + log_rewards[num_offline:] = replay_logr + flat_rewards[num_offline:] = replay_fr + cond_info[num_offline:] = replay_condinfo + is_valid[num_offline:] = replay_valid + + # convert cond_info back to a dict + cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} + + if self.hindsight_ratio > 0.0: + # Relabels some of the online trajectories with hindsight + assert hasattr( + self.task, "relabel_condinfo_and_logrewards" + ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" + # samples indexes of trajectories without repeats + hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline + cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( + cond_info, log_rewards, flat_rewards, hindsight_idxs + ) + log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + + # Construct batch + batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) + batch.num_offline = num_offline + batch.num_online = num_online + batch.flat_rewards = flat_rewards + batch.preferences = cond_info.get("preferences", None) + batch.focus_dir = cond_info.get("focus_dir", None) + batch.extra_info = extra_info + # TODO: we could very well just pass the cond_info dict to construct_batch above, + # and the algo can decide what it wants to put in the batch object + + # Only activate for debugging your environment or dataset (e.g. the dataset could be + # generating trajectories with illegal actions) + if self.do_validate_batch: + self.validate_batch(batch, trajs) + + self.train_it += worker_info.num_workers if worker_info is not None else 1 + if self.result_buffer_size and not self.do_multiple_buffers: + with self.round_robin_cond: + self.round_robin_cond.wait_for(lambda: self.round_robin_counter[0] == self._wid) + self.result_buffer.lock.acquire() + yield put_into_batch_buffer(batch, self.result_buffer.buffer) + self.round_robin_counter[0] = (self._wid + 1) % worker_info.num_workers + with self.round_robin_cond: + self.round_robin_cond.notify_all() + elif self.result_buffer_size: + self.result_buffer[self._wid].lock.acquire() + desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) + desc.wid = self._wid + yield desc + else: + yield batch + + def validate_batch(self, batch, trajs): + for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( + [(batch.bck_actions, self.ctx.bck_action_type_order)] + if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") + else [] + ): + mask_cat = GraphActionCategorical( + batch, + [self.model._action_type_to_mask(t, batch) for t in atypes], + [self.model._action_type_to_key[t] for t in atypes], + [None for _ in atypes], + ) + masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) + num_trajs = len(trajs) + batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) + first_graph_idx = torch.zeros_like(batch.traj_lens) + torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) + if masked_action_is_used.sum() != 0: + invalid_idx = masked_action_is_used.argmax().item() + traj_idx = batch_idx[invalid_idx].item() + timestep = invalid_idx - first_graph_idx[traj_idx].item() + raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) + + def log_generated(self, trajs, rewards, flat_rewards, cond_info): + if hasattr(self.ctx, "object_to_log_repr"): + mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + else: + mols = [""] * len(trajs) + + flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + rewards = rewards.data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] + + data = [ + [mols[i], rewards[i]] + + flat_rewards[i] + + preferences[i] + + focus_dir[i] + + [cond_info[k][i].item() for k in logged_keys] + for i in range(len(trajs)) + ] + + data_labels = ( + ["smi", "r"] + + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"pref_{i}" for i in range(len(preferences[0]))] + + [f"focus_{i}" for i in range(len(focus_dir[0]))] + + [f"ci_{k}" for k in logged_keys] + ) + + self.log.insert_many(data, data_labels) + + +class SQLiteLog: + def __init__(self, timeout=300): + """Creates a log instance, but does not connect it to any db.""" + self.is_connected = False + self.db = None + self.timeout = timeout + + def connect(self, db_path: str): + """Connects to db_path + + Parameters + ---------- + db_path: str + The sqlite3 database path. If it does not exist, it will be created. + """ + self.db = sqlite3.connect(db_path, timeout=self.timeout) + cur = self.db.cursor() + self._has_results_table = len( + cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() + ) + cur.close() + + def _make_results_table(self, types, names): + type_map = {str: "text", float: "real", int: "real"} + col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) + cur = self.db.cursor() + cur.execute(f"create table results ({col_str})") + self._has_results_table = True + cur.close() + + def insert_many(self, rows, column_names): + assert all( + [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] + ), "rows must only contain scalars" + if not self._has_results_table: + self._make_results_table([type(i) for i in rows[0]], column_names) + cur = self.db.cursor() + cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec + cur.close() + self.db.commit() diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..6dae55b7 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -245,6 +245,8 @@ def _make_cat(self, g, emb, action_types): ) def forward(self, g: gd.Batch, cond: torch.Tensor): + if cond is None: + cond = g.cond_info node_embeddings, graph_embeddings = self.transf(g, cond) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index e64d642d..18db5a2c 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -216,6 +216,7 @@ def main(): config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] + config.mp_buffer_size = None # 32 * 1024 ** 2, trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index c8be7312..12899958 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -22,8 +22,8 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook +from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, resolve_batch_buffer, BatchDescriptor from .config import Config @@ -131,6 +131,7 @@ def _wrap_for_mp(self, obj): self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, + bb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -140,7 +141,32 @@ def _wrap_for_mp(self, obj): def build_callbacks(self): return {} +<<<<<<< HEAD def _make_data_loader(self, src): +======= + def build_training_data_loader(self) -> DataLoader: + model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) + replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) + iterator = SamplingIterator( + self.training_data, + model, + self.ctx, + self.algo, + self.task, + dev, + batch_size=self.cfg.algo.global_batch_size, + illegal_action_logreward=self.cfg.algo.illegal_action_logreward, + replay_buffer=replay_buffer, + ratio=self.cfg.algo.offline_ratio, + log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), + random_action_prob=self.cfg.algo.train_random_action_prob, + hindsight_ratio=self.cfg.replay.hindsight_ratio, + buffer_size=self.cfg.mp_buffer_size, + num_workers=self.cfg.num_workers, + ) + for hook in self.sampling_hooks: + iterator.add_log_hook(hook) +>>>>>>> proof of concept of using shared pinned buffers return torch.utils.data.DataLoader( src, batch_size=None, @@ -266,6 +292,7 @@ def run(self, logger=None): start = self.cfg.start_at_step + 1 num_training_steps = self.cfg.num_training_steps logger.info("Starting training") +<<<<<<< HEAD start_time = time.time() for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? @@ -274,6 +301,27 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() +======= + import time + t0 = time.time() + times = [] + for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): + if isinstance(batch, BatchDescriptor): + print(f"buffer size was {batch.size / 1024**2:.2f}M") + if train_dl.dataset.do_multiple_buffers: + wid = batch.wid + batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer[wid].buffer, self.device) + train_dl.dataset.result_buffer[wid].lock.release() + else: + batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer.buffer, self.device) + train_dl.dataset.result_buffer.lock.release() + else: + batch = batch.to(self.device) + t1 = time.time() + times.append(t1 - t0) + print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") + t0 = t1 +>>>>>>> proof of concept of using shared pinned buffers epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -281,9 +329,13 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue +<<<<<<< HEAD info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() +======= + info = self.train_batch(batch, epoch_idx, batch_idx, it) +>>>>>>> proof of concept of using shared pinned buffers self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index df13b565..618d98f2 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -2,9 +2,169 @@ import queue import threading import traceback +from itertools import chain +import numpy as np import torch import torch.multiprocessing as mp +from torch_geometric.data import Batch + +from gflownet.envs.graph_building_env import GraphActionCategorical + + +class SharedPinnedBuffer: + def __init__(self, size): + self.size = size + self.buffer = torch.empty(size, dtype=torch.uint8) + self.buffer.share_memory_() + self.lock = mp.Lock() + + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + assert self.buffer.is_shared() + assert self.buffer.is_pinned() + + +class BatchDescriptor: + def __init__(self, names, types, shapes, size, other): + self.names = names + self.types = types + self.shapes = shapes + self.size = size + self.other = other + + +class ResultDescriptor: + def __init__(self, names, types, shapes, size, gac_attrs): + self.names = names + self.types = types + self.shapes = shapes + self.size = size + self.gac_attrs = gac_attrs + + +def prod(l): + p = 1 + for i in l: + p *= i + return p + + +def put_into_batch_buffer(batch, buffer): + names = [] + types = [] + shapes = [] + offset = 0 + others = {} + for k, v in chain(batch._store.items(), (("_slice_dict_" + k, v) for k, v in batch._slice_dict.items())): + if not isinstance(v, torch.Tensor): + try: + v = torch.as_tensor(v) + except Exception as e: + others[k] = v + continue + names.append(k) + types.append(v.dtype) + shapes.append(tuple(v.shape)) + numel = v.numel() * v.element_size() + # print('putting', k, v.shape, numel, offset) + buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + if offset > buffer.shape[0]: + raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") + # print(f'total size: {offset / 1024**2:.3f}M') + # print(batch.batch) + return BatchDescriptor(names, types, shapes, offset, others) + + +def resolve_batch_buffer(descriptor, buffer, device): + offset = 0 + batch = Batch() + batch._slice_dict = {} + cuda_buffer = buffer[: descriptor.size].to(device) # TODO: check if only sending `size` is faster? + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): + numel = prod(shape) * dtype.itemsize + # print('restoring', name, shape, numel, offset) + if name.startswith("_slice_dict_"): + batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) + else: + setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + # print(batch.batch) + # print(f'total size: {offset / 1024**2:.3f}M') + for k, v in descriptor.other.items(): + setattr(batch, k, v) + return batch + + +def put_into_result_buffer(result, buffer): + gac_names = ["logits", "batch", "slice", "masks"] + gac, tensor = result + buffer[: tensor.numel() * tensor.element_size()] = tensor.view(-1).view(torch.uint8) + offset = tensor.numel() * tensor.element_size() + offset += (8 - offset % 8) % 8 # align to 8 bytes + names = ["@per_graph_out"] + types = [tensor.dtype] + shapes = [tensor.shape] + for name in gac_names: + tensors = getattr(gac, name) + for i, x in enumerate(tensors): + # print(f"putting {name}@{i} with shape {x.shape}") + numel = x.numel() * x.element_size() + if numel > 0: + # We need this for a funny reason + # torch.zeros(0)[::2] has a stride of (2,), and is contiguous according to torch + # so, flattening it and then reshaping it will not change the stride, which will + # make view(uint8) complain that the strides are not compatible. + # The batch[::2] happens when creating the categorical and deduplicate_edge_index is True + buffer[offset : offset + numel] = x.flatten().view(torch.uint8) + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + if offset > buffer.shape[0]: + raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") + names.append(f"{name}@{i}") + types.append(x.dtype) + shapes.append(tuple(x.shape)) + return ResultDescriptor(names, types, shapes, offset, (gac.num_graphs, gac.keys, gac.types)) + + +def resolve_result_buffer(descriptor, buffer, device): + # TODO: models can return multiple GraphActionCategoricals, but we only support one for now + # Would be nice to have something generic (and recursive?) + offset = 0 + tensor = buffer[: descriptor.size].to(device) + if tensor.device == device: # CPU to CPU + # I think we need this? Otherwise when we release the lock, the memory might be overwritten + tensor = tensor.clone() + # Maybe make this a static method, or just overload __new__? + gac = GraphActionCategorical.__new__(GraphActionCategorical) + gac.num_graphs, gac.keys, gac.types = descriptor.gac_attrs + gac.dev = device + gac.logprobs = None + gac._epsilon = 1e-38 + + gac_names = ["logits", "batch", "slice", "masks"] + for i in gac_names: + setattr(gac, i, [None] * len(gac.types)) + + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): + numel = prod(shape) * dtype.itemsize + if name == "@per_graph_out": + per_graph_out = tensor[offset : offset + numel].view(dtype).view(shape) + else: + name, index = name.split("@") + index = int(index) + if name in gac_names: + getattr(gac, name)[index] = tensor[offset : offset + numel].view(dtype).view(shape) + else: + raise ValueError(f"Unknown result descriptor name: {name}") + offset += numel + offset += (8 - offset % 8) % 8 # align to 8 bytes + # print(f"restored {name} with shape {shape}") + return gac, per_graph_out class MPObjectPlaceholder: @@ -12,11 +172,15 @@ class MPObjectPlaceholder: in a worker process, and translates calls to the object-placeholder into queries for the main process to execute on the real object.""" - def __init__(self, in_queues, out_queues, pickle_messages=False): + def __init__(self, in_queues, out_queues, pickle_messages=False, batch_buffer_size=None): self.qs = in_queues, out_queues self.device = torch.device("cpu") self.pickle_messages = pickle_messages self._is_init = False + self.batch_buffer_size = batch_buffer_size + if batch_buffer_size is not None: + self._batch_buffer = SharedPinnedBuffer(batch_buffer_size) + self._result_buffer = SharedPinnedBuffer(batch_buffer_size) def _check_init(self): if self._is_init: @@ -41,6 +205,9 @@ def decode(self, m): if isinstance(m, Exception): print("Received exception from main process, reraising.") raise m + if isinstance(m, ResultDescriptor): + m = resolve_result_buffer(m, self._result_buffer.buffer, self.device) + self._result_buffer.lock.release() return m def __getattr__(self, name): @@ -53,6 +220,11 @@ def method_wrapper(*a, **kw): def __call__(self, *a, **kw): self._check_init() + if self.batch_buffer_size and len(a) and isinstance(a[0], Batch): + # The lock will be released by the consumer of this buffer once the memory has been transferred to CUDA + self._batch_buffer.lock.acquire() + batch_descriptor = put_into_batch_buffer(a[0], self._batch_buffer.buffer) + a = (batch_descriptor,) + a[1:] self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) @@ -75,7 +247,7 @@ class MPObjectProxy: Always passes CPU tensors between processes. """ - def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False): + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, bb_size=None): """Construct a multiprocessing object proxy. Parameters @@ -91,11 +263,13 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo If True, pickle messages sent between processes. This reduces load on shared memory, but increases load on CPU. It is recommended to activate this flag if encountering "Too many open files"-type errors. + bb_size: Optional[int] + batch buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages) + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, bb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -109,6 +283,15 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo def encode(self, m): if self.pickle_messages: return pickle.dumps(m) + if ( + self.placeholder.batch_buffer_size + and isinstance(m, (list, tuple)) + and len(m) == 2 + and isinstance(m[0], GraphActionCategorical) + and isinstance(m[1], torch.Tensor) + ): + self.placeholder._result_buffer.lock.acquire() + return put_into_result_buffer(m, self.placeholder._result_buffer.buffer) return m def decode(self, m): @@ -133,6 +316,12 @@ def run(self): break timeouts = 0 attr, args, kwargs = r + if self.placeholder.batch_buffer_size and len(args) and isinstance(args[0], BatchDescriptor): + batch = resolve_batch_buffer(args[0], self.placeholder._batch_buffer.buffer, self.device) + args = (batch,) + args[1:] + # Should this release happen after the call to f()? Are we at risk of overwriting memory that + # is still being used by CUDA? + self.placeholder._batch_buffer.lock.release() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -143,6 +332,7 @@ def run(self): except Exception as e: result = e exc_str = traceback.format_exc() + print(exc_str) try: pickle.dumps(e) except Exception: @@ -159,7 +349,7 @@ def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, bb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes From d4a2a7ddf512895d8c0b2e5bd67fd530395c283f Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 23 Feb 2024 15:52:41 -0700 Subject: [PATCH 13/21] 32mb buffer --- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/trainer.py | 40 +++------------------------------- 2 files changed, 4 insertions(+), 38 deletions(-) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 18db5a2c..41e67df4 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -216,7 +216,7 @@ def main(): config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = None # 32 * 1024 ** 2, + config.mp_buffer_size = 32 * 1024 ** 2 trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 12899958..0d758c78 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -141,32 +141,7 @@ def _wrap_for_mp(self, obj): def build_callbacks(self): return {} -<<<<<<< HEAD def _make_data_loader(self, src): -======= - def build_training_data_loader(self) -> DataLoader: - model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) - iterator = SamplingIterator( - self.training_data, - model, - self.ctx, - self.algo, - self.task, - dev, - batch_size=self.cfg.algo.global_batch_size, - illegal_action_logreward=self.cfg.algo.illegal_action_logreward, - replay_buffer=replay_buffer, - ratio=self.cfg.algo.offline_ratio, - log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), - random_action_prob=self.cfg.algo.train_random_action_prob, - hindsight_ratio=self.cfg.replay.hindsight_ratio, - buffer_size=self.cfg.mp_buffer_size, - num_workers=self.cfg.num_workers, - ) - for hook in self.sampling_hooks: - iterator.add_log_hook(hook) ->>>>>>> proof of concept of using shared pinned buffers return torch.utils.data.DataLoader( src, batch_size=None, @@ -292,8 +267,9 @@ def run(self, logger=None): start = self.cfg.start_at_step + 1 num_training_steps = self.cfg.num_training_steps logger.info("Starting training") -<<<<<<< HEAD start_time = time.time() + t0 = time.time() + times = [] for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? # is changing the allocation strategy helpful? @@ -301,11 +277,6 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() -======= - import time - t0 = time.time() - times = [] - for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): if isinstance(batch, BatchDescriptor): print(f"buffer size was {batch.size / 1024**2:.2f}M") if train_dl.dataset.do_multiple_buffers: @@ -321,7 +292,6 @@ def run(self, logger=None): times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") t0 = t1 ->>>>>>> proof of concept of using shared pinned buffers epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -329,13 +299,9 @@ def run(self, logger=None): f"iteration {it} : warming up replay buffer {len(self.replay_buffer)}/{self.replay_buffer.warmup}" ) continue -<<<<<<< HEAD - info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) + info = self.train_batch(batch, epoch_idx, batch_idx, it) info["time_spent"] = time.time() - start_time start_time = time.time() -======= - info = self.train_batch(batch, epoch_idx, batch_idx, it) ->>>>>>> proof of concept of using shared pinned buffers self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) From 27dfc23a52c0dfb204669b493860e624e1470d2b Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 7 Mar 2024 13:07:21 -0700 Subject: [PATCH 14/21] add to DataSource --- src/gflownet/data/data_source.py | 19 +- src/gflownet/data/sampling_iterator.py | 453 -------------------- src/gflownet/models/seq_transformer.py | 3 + src/gflownet/trainer.py | 1 + src/gflownet/utils/multiprocessing_proxy.py | 15 +- 5 files changed, 27 insertions(+), 464 deletions(-) delete mode 100644 src/gflownet/data/sampling_iterator.py diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index d78a2a7f..caba7399 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -3,6 +3,7 @@ import numpy as np import torch +import torch.multiprocessing as mp from torch.utils.data import IterableDataset from gflownet import GFNAlgorithm, GFNTask @@ -10,6 +11,7 @@ from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng +from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer def cycle_call(it): @@ -44,6 +46,7 @@ def __init__( self.global_step_count.share_memory_() self.global_step_count_lock = torch.multiprocessing.Lock() self.current_iter = start_at_step + self.setup_mp_buffers() def add_sampling_hook(self, hook: Callable): """Add a hook that is called when sampling new trajectories. @@ -230,7 +233,7 @@ def create_batch(self, trajs, batch_info): batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) batch.flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) - return batch + return self._maybe_put_in_mp_buffer(batch) def compute_properties(self, trajs, mark_as_online=False): """Sets trajs' flat_rewards and is_valid keys by querying the task.""" @@ -318,3 +321,17 @@ def iterate_indices(self, n, num_samples): yield np.arange(i, i + num_samples) if i + num_samples < end: yield np.arange(i + num_samples, end) + + def setup_mp_buffers(self): + self.result_buffer_size = self.cfg.mp_buffer_size + if self.result_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + + def _maybe_put_in_mp_buffer(self, batch): + if self.result_buffer_size: + self.result_buffer[self._wid].lock.acquire() + desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) + desc.wid = self._wid + return desc + else: + return batch diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py deleted file mode 100644 index 8daec21b..00000000 --- a/src/gflownet/data/sampling_iterator.py +++ /dev/null @@ -1,453 +0,0 @@ -import os -import sqlite3 -from collections.abc import Iterable -from copy import deepcopy -from typing import Callable, List - -import numpy as np -import torch -import torch.nn as nn -import torch.multiprocessing as mp -from rdkit import RDLogger -from torch.utils.data import Dataset, IterableDataset - -from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphActionCategorical -from gflownet.utils.multiprocessing_proxy import put_into_batch_buffer, SharedPinnedBuffer - - -class SamplingIterator(IterableDataset): - """This class allows us to parallelise and train faster. - - By separating sampling data/the model and building torch geometric - graphs from training the model, we can do the former in different - processes, which is much faster since much of graph construction - is CPU-bound. - - """ - - def __init__( - self, - dataset: Dataset, - model: nn.Module, - ctx, - algo, - task, - device, - batch_size: int = 1, - illegal_action_logreward: float = -50, - ratio: float = 0.5, - stream: bool = True, - replay_buffer: ReplayBuffer = None, - log_dir: str = None, - sample_cond_info: bool = True, - random_action_prob: float = 0.0, - hindsight_ratio: float = 0.0, - init_train_iter: int = 0, - buffer_size: int = None, - num_workers: int = 1, - do_multiple_buffers = True, # If True, each worker has its own buffer; doesn't seem to have much impact either way - ): - """Parameters - ---------- - dataset: Dataset - A dataset instance - model: nn.Module - The model we sample from (must be on CUDA already or share_memory() must be called so that - parameters are synchronized between each worker) - ctx: - The context for the environment, e.g. a MolBuildingEnvContext instance - algo: - The training algorithm, e.g. a TrajectoryBalance instance - task: GFNTask - A Task instance, e.g. a MakeRingsTask instance - device: torch.device - The device the model is on - replay_buffer: ReplayBuffer - The replay buffer for training on past data - batch_size: int - The number of trajectories, each trajectory will be comprised of many graphs, so this is - _not_ the batch size in terms of the number of graphs (that will depend on the task) - illegal_action_logreward: float - The logreward for invalid trajectories - ratio: float - The ratio of offline trajectories in the batch. - stream: bool - If True, data is sampled iid for every batch. Otherwise, this is a normal in-order - dataset iterator. - log_dir: str - If not None, logs each SamplingIterator worker's generated molecules to that file. - sample_cond_info: bool - If True (default), then the dataset is a dataset of points used in offline training. - If False, then the dataset is a dataset of preferences (e.g. used to validate the model) - random_action_prob: float - The probability of taking a random action, passed to the graph sampler - init_train_iter: int - The initial training iteration, incremented and passed to task.sample_conditional_information - """ - self.data = dataset - self.model = model - self.replay_buffer = replay_buffer - self.batch_size = batch_size - self.illegal_action_logreward = illegal_action_logreward - self.offline_batch_size = int(np.ceil(self.batch_size * ratio)) - self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio))) - self.ratio = ratio - self.ctx = ctx - self.algo = algo - self.task = task - self.device = device - self.stream = stream - self.sample_online_once = True # TODO: deprecate this, disallow len(data) == 0 entirely - self.sample_cond_info = sample_cond_info - self.random_action_prob = random_action_prob - self.hindsight_ratio = hindsight_ratio - self.train_it = init_train_iter - self.do_validate_batch = False # Turn this on for debugging - - self.result_buffer_size = buffer_size - self.do_multiple_buffers = do_multiple_buffers - if buffer_size and do_multiple_buffers: - self.result_buffer = [SharedPinnedBuffer(buffer_size) for _ in range(num_workers)] - elif buffer_size: - self.result_buffer = SharedPinnedBuffer(buffer_size) - self.round_robin_cond = mp.Condition() - self.round_robin_counter = torch.zeros(1) - self.round_robin_counter.share_memory_() - - # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) - # then "offline" now refers to cond info and online to x, so no duplication and we don't end - # up with 2*batch_size accidentally - if not sample_cond_info: - self.offline_batch_size = self.online_batch_size = self.batch_size - - # This SamplingIterator instance will be copied by torch DataLoaders for each worker, so we - # don't want to initialize per-worker things just yet, such as where the log the worker writes - # to. This must be done in __iter__, which is called by the DataLoader once this instance - # has been copied into a new python process. - self.log_dir = log_dir - self.log = SQLiteLog() - self.log_hooks: List[Callable] = [] - - def add_log_hook(self, hook: Callable): - self.log_hooks.append(hook) - - def _idx_iterator(self): - RDLogger.DisableLog("rdApp.*") - if self.stream: - # If we're streaming data, just sample `offline_batch_size` indices - while True: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) - else: - # Otherwise, figure out which indices correspond to this worker - worker_info = torch.utils.data.get_worker_info() - n = len(self.data) - if n == 0: - yield np.arange(0, 0) - return - assert ( - self.offline_batch_size > 0 - ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)" - if worker_info is None: # no multi-processing - start, end, wid = 0, n, -1 - else: # split the data into chunks (per-worker) - nw = worker_info.num_workers - wid = worker_info.id - start, end = int(np.round(n / nw * wid)), int(np.round(n / nw * (wid + 1))) - bs = self.offline_batch_size - if end - start <= bs: - yield np.arange(start, end) - return - for i in range(start, end - bs, bs): - yield np.arange(i, i + bs) - if i + bs < end: - yield np.arange(i + bs, end) - - def __len__(self): - if self.stream: - return int(1e6) - if len(self.data) == 0 and self.sample_online_once: - return 1 - return len(self.data) - - def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - self._wid = worker_info.id if worker_info is not None else 0 - # Now that we know we are in a worker instance, we can initialize per-worker things - self.rng = self.algo.rng = self.task.rng = np.random.default_rng(142857 + self._wid) - self.ctx.device = self.device - if self.log_dir is not None: - os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" - self.log.connect(self.log_path) - - for idcs in self._idx_iterator(): - num_offline = idcs.shape[0] # This is in [0, self.offline_batch_size] - # Sample conditional info such as temperature, trade-off weights, etc. - - if self.sample_cond_info: - num_online = self.online_batch_size - cond_info = self.task.sample_conditional_information( - num_offline + self.online_batch_size, self.train_it - ) - - # Sample some dataset data - graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], []) - flat_rewards = ( - list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] - ) - trajs = self.algo.create_training_data_from_graphs( - graphs, self.model, cond_info["encoding"][:num_offline], 0 - ) - - else: # If we're not sampling the conditionals, then the idcs refer to listed preferences - num_online = num_offline - num_offline = 0 - cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) - ) - trajs, flat_rewards = [], [] - - # Sample some on-policy data - is_valid = torch.ones(num_offline + num_online).bool() - if num_online > 0: - with torch.no_grad(): - trajs += self.algo.create_training_data_from_own_samples( - self.model, - num_online, - cond_info["encoding"][num_offline:], - random_action_prob=self.random_action_prob, - ) - if self.algo.bootstrap_own_reward: - # The model can be trained to predict its own reward, - # i.e. predict the output of cond_info_to_logreward - pred_reward = [i["reward_pred"].cpu().item() for i in trajs[num_offline:]] - flat_rewards += pred_reward - else: - # Otherwise, query the task for flat rewards - valid_idcs = torch.tensor( - [i + num_offline for i in range(num_online) if trajs[i + num_offline]["is_valid"]] - ).long() - # fetch the valid trajectories endpoints - mols = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] - # ask the task to compute their reward - online_flat_rew, m_is_valid = self.task.compute_flat_rewards(mols) - assert ( - online_flat_rew.ndim == 2 - ), "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" - # The task may decide some of the mols are invalid, we have to again filter those - valid_idcs = valid_idcs[m_is_valid] - pred_reward = torch.zeros((num_online, online_flat_rew.shape[1])) - pred_reward[valid_idcs - num_offline] = online_flat_rew - is_valid[num_offline:] = False - is_valid[valid_idcs] = True - flat_rewards += list(pred_reward) - # Override the is_valid key in case the task made some mols invalid - for i in range(num_online): - trajs[num_offline + i]["is_valid"] = is_valid[num_offline + i].item() - - # Compute scalar rewards from conditional information & flat rewards - flat_rewards = torch.stack(flat_rewards) - log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Computes some metrics - extra_info = {} - if not self.sample_cond_info: - # If we're using a dataset of preferences, the user may want to know the id of the preference - for i, j in zip(trajs, idcs): - i["data_idx"] = j - # note: we convert back into natural rewards for logging purposes - # (allows to take averages and plot in objective space) - # TODO: implement that per-task (in case they don't apply the same beta and log transformations) - rewards = torch.exp(log_rewards / cond_info["beta"]) - if num_online > 0 and self.log_dir is not None: - self.log_generated( - deepcopy(trajs[num_offline:]), - deepcopy(rewards[num_offline:]), - deepcopy(flat_rewards[num_offline:]), - {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, - ) - if num_online > 0: - extra_info["sampled_reward_avg"] = rewards[num_offline:].mean().item() - for hook in self.log_hooks: - extra_info.update( - hook( - deepcopy(trajs[num_offline:]), - deepcopy(rewards[num_offline:]), - deepcopy(flat_rewards[num_offline:]), - {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, - ) - ) - - if self.replay_buffer is not None: - # If we have a replay buffer, we push the online trajectories in it - # and resample immediately such that the "online" data in the batch - # comes from a more stable distribution (try to avoid forgetting) - - # cond_info is a dict, so we need to convert it to a list of dicts - cond_info = [{k: v[i] for k, v in cond_info.items()} for i in range(num_offline + num_online)] - - # push the online trajectories in the replay buffer and sample a new 'online' batch - for i in range(num_offline, len(trajs)): - self.replay_buffer.push( - deepcopy(trajs[i]), - deepcopy(log_rewards[i]), - deepcopy(flat_rewards[i]), - deepcopy(cond_info[i]), - deepcopy(is_valid[i]), - ) - replay_trajs, replay_logr, replay_fr, replay_condinfo, replay_valid = self.replay_buffer.sample( - num_online - ) - - # append the online trajectories to the offline ones - trajs[num_offline:] = replay_trajs - log_rewards[num_offline:] = replay_logr - flat_rewards[num_offline:] = replay_fr - cond_info[num_offline:] = replay_condinfo - is_valid[num_offline:] = replay_valid - - # convert cond_info back to a dict - cond_info = {k: torch.stack([d[k] for d in cond_info]) for k in cond_info[0]} - - if self.hindsight_ratio > 0.0: - # Relabels some of the online trajectories with hindsight - assert hasattr( - self.task, "relabel_condinfo_and_logrewards" - ), "Hindsight requires the task to implement relabel_condinfo_and_logrewards" - # samples indexes of trajectories without repeats - hindsight_idxs = torch.randperm(num_online)[: int(num_online * self.hindsight_ratio)] + num_offline - cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( - cond_info, log_rewards, flat_rewards, hindsight_idxs - ) - log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward - - # Construct batch - batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards) - batch.num_offline = num_offline - batch.num_online = num_online - batch.flat_rewards = flat_rewards - batch.preferences = cond_info.get("preferences", None) - batch.focus_dir = cond_info.get("focus_dir", None) - batch.extra_info = extra_info - # TODO: we could very well just pass the cond_info dict to construct_batch above, - # and the algo can decide what it wants to put in the batch object - - # Only activate for debugging your environment or dataset (e.g. the dataset could be - # generating trajectories with illegal actions) - if self.do_validate_batch: - self.validate_batch(batch, trajs) - - self.train_it += worker_info.num_workers if worker_info is not None else 1 - if self.result_buffer_size and not self.do_multiple_buffers: - with self.round_robin_cond: - self.round_robin_cond.wait_for(lambda: self.round_robin_counter[0] == self._wid) - self.result_buffer.lock.acquire() - yield put_into_batch_buffer(batch, self.result_buffer.buffer) - self.round_robin_counter[0] = (self._wid + 1) % worker_info.num_workers - with self.round_robin_cond: - self.round_robin_cond.notify_all() - elif self.result_buffer_size: - self.result_buffer[self._wid].lock.acquire() - desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) - desc.wid = self._wid - yield desc - else: - yield batch - - def validate_batch(self, batch, trajs): - for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + ( - [(batch.bck_actions, self.ctx.bck_action_type_order)] - if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order") - else [] - ): - mask_cat = GraphActionCategorical( - batch, - [self.model._action_type_to_mask(t, batch) for t in atypes], - [self.model._action_type_to_key[t] for t in atypes], - [None for _ in atypes], - ) - masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits) - num_trajs = len(trajs) - batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens) - first_graph_idx = torch.zeros_like(batch.traj_lens) - torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:]) - if masked_action_is_used.sum() != 0: - invalid_idx = masked_action_is_used.argmax().item() - traj_idx = batch_idx[invalid_idx].item() - timestep = invalid_idx - first_graph_idx[traj_idx].item() - raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep]) - - def log_generated(self, trajs, rewards, flat_rewards, cond_info): - if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] - else: - mols = [""] * len(trajs) - - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() - rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() - logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] - - data = [ - [mols[i], rewards[i]] - + flat_rewards[i] - + preferences[i] - + focus_dir[i] - + [cond_info[k][i].item() for k in logged_keys] - for i in range(len(trajs)) - ] - - data_labels = ( - ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] - + [f"pref_{i}" for i in range(len(preferences[0]))] - + [f"focus_{i}" for i in range(len(focus_dir[0]))] - + [f"ci_{k}" for k in logged_keys] - ) - - self.log.insert_many(data, data_labels) - - -class SQLiteLog: - def __init__(self, timeout=300): - """Creates a log instance, but does not connect it to any db.""" - self.is_connected = False - self.db = None - self.timeout = timeout - - def connect(self, db_path: str): - """Connects to db_path - - Parameters - ---------- - db_path: str - The sqlite3 database path. If it does not exist, it will be created. - """ - self.db = sqlite3.connect(db_path, timeout=self.timeout) - cur = self.db.cursor() - self._has_results_table = len( - cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='results'").fetchall() - ) - cur.close() - - def _make_results_table(self, types, names): - type_map = {str: "text", float: "real", int: "real"} - col_str = ", ".join(f"{name} {type_map[t]}" for t, name in zip(types, names)) - cur = self.db.cursor() - cur.execute(f"create table results ({col_str})") - self._has_results_table = True - cur.close() - - def insert_many(self, rows, column_names): - assert all( - [isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]] - ), "rows must only contain scalars" - if not self._has_results_table: - self._make_results_table([type(i) for i in rows[0]], column_names) - cur = self.db.cursor() - cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec - cur.close() - self.db.commit() diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 6916366a..751cf290 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -77,6 +77,9 @@ def forward(self, xs: SeqBatch, cond, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) + if cond is None: + cond = xs.cond_info + if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 0d758c78..6bddf98b 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -277,6 +277,7 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() + if isinstance(batch, BatchDescriptor): print(f"buffer size was {batch.size / 1024**2:.2f}M") if train_dl.dataset.do_multiple_buffers: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 618d98f2..d92bb61f 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -68,14 +68,13 @@ def put_into_batch_buffer(batch, buffer): types.append(v.dtype) shapes.append(tuple(v.shape)) numel = v.numel() * v.element_size() - # print('putting', k, v.shape, numel, offset) buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes if offset > buffer.shape[0]: - raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") - # print(f'total size: {offset / 1024**2:.3f}M') - # print(batch.batch) + raise ValueError( + f"Offset {offset} exceeds buffer size {buffer.shape[0]}. Try increasing `cfg.mp_buffer_size`." + ) return BatchDescriptor(names, types, shapes, offset, others) @@ -83,18 +82,16 @@ def resolve_batch_buffer(descriptor, buffer, device): offset = 0 batch = Batch() batch._slice_dict = {} - cuda_buffer = buffer[: descriptor.size].to(device) # TODO: check if only sending `size` is faster? + cuda_buffer = buffer[: descriptor.size].to(device) for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): numel = prod(shape) * dtype.itemsize - # print('restoring', name, shape, numel, offset) if name.startswith("_slice_dict_"): batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) else: setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes - # print(batch.batch) - # print(f'total size: {offset / 1024**2:.3f}M') + for k, v in descriptor.other.items(): setattr(batch, k, v) return batch @@ -112,7 +109,6 @@ def put_into_result_buffer(result, buffer): for name in gac_names: tensors = getattr(gac, name) for i, x in enumerate(tensors): - # print(f"putting {name}@{i} with shape {x.shape}") numel = x.numel() * x.element_size() if numel > 0: # We need this for a funny reason @@ -163,7 +159,6 @@ def resolve_result_buffer(descriptor, buffer, device): raise ValueError(f"Unknown result descriptor name: {name}") offset += numel offset += (8 - offset % 8) % 8 # align to 8 bytes - # print(f"restored {name} with shape {shape}") return gac, per_graph_out From e9f1dc13824a0ba621c7f71e05cd2005bee4945d Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 8 Mar 2024 09:16:53 -0700 Subject: [PATCH 15/21] various fixes --- docs/implementation_notes.md | 14 +++++++++ src/gflownet/algo/config.py | 1 + src/gflownet/data/data_source.py | 17 +++++++---- src/gflownet/tasks/seh_frag.py | 5 ++-- src/gflownet/trainer.py | 32 +++++++++++---------- src/gflownet/utils/multiprocessing_proxy.py | 27 +++++++++++++---- 6 files changed, 68 insertions(+), 28 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 600bb1d4..696095f1 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -34,3 +34,17 @@ The code contains a specific categorical distribution type for graph actions, `G Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. + +## Multiprocessing + +We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` (via `cfg.num_workers`) parameter of the `DataLoader` to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks. + +Because training models involves sampling them, the worker processes need to be able to call the models. This is done by passing a wrapped model (and possibly wrapped replay buffer) to the workers, using `gflownet.utils.multiprocessing_proxy`. These wrappers ensure that model calls are routed to the main worker process, where the model lives (e.g. in CUDA), and that the returned values are properly serialized and sent back to the worker process. These wrappers are also designed to be API-compatible with models, e.g. `model(input)` or `model.method(input)` will work as expected, regardless of whether `model` is a torch module or a wrapper. Note that it is only possible to call methods on these wrappers, direct attribute access is not supported. + +Note that the workers do not use CUDA, therefore have to work entirely on CPU, but the code is designed to be somewhat agnostic to this fact. By using `get_worker_device`, code can be written without assuming too much; again, calls such as `model(input)` will work as expected. + +On message serialization, naively sending batches of data and results (`Batch` and `GraphActionCategorical`) through multiprocessing queues is fairly inefficient. Torch tries to be smart and will use shared memory for tensors that are sent through queues, which unfortunately is very slow because creating these shared memory files is slow, and because `Data` `Batch`es tend to contain lots of small tensors, which is not a good fit for shared memory. + +We implement two solutions to this problem (in order of preference): +- using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but that the size of the largest possible batch/return value is known in advance. This is currently only implemented for `Batch` inputs and `(GraphActionCategorical, Tensor)` outputs. +- using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creating of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle. \ No newline at end of file diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index e2576982..70780359 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -138,6 +138,7 @@ class AlgoConfig: train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 + compute_log_n: bool = False tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index caba7399..eabad5cd 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -5,11 +5,12 @@ import torch import torch.multiprocessing as mp from torch.utils.data import IterableDataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical from gflownet.utils.misc import get_worker_rng from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer @@ -228,7 +229,7 @@ def create_batch(self, trajs, batch_info): if "focus_dir" in trajs[0]: batch.focus_dir = torch.stack([t["focus_dir"] for t in trajs]) - if self.ctx.has_n(): # Does this go somewhere else? Require a flag? Might not be cheap to compute + if self.ctx.has_n() and self.cfg.algo.compute_log_n: log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) @@ -323,12 +324,18 @@ def iterate_indices(self, n, num_samples): yield np.arange(i + num_samples, end) def setup_mp_buffers(self): - self.result_buffer_size = self.cfg.mp_buffer_size - if self.result_buffer_size: - self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + if self.cfg.num_workers > 0: + self.result_buffer_size = self.cfg.mp_buffer_size + if self.result_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + else: + self.result_buffer_size = None def _maybe_put_in_mp_buffer(self, batch): if self.result_buffer_size: + if not (isinstance(batch, Batch)): + warnings.warn(f"Expected a Batch object, but got {type(batch)}. " "Not using mp buffers.") + return batch self.result_buffer[self._wid].lock.acquire() desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) desc.wid = self._wid diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 41e67df4..b5179124 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -208,9 +208,10 @@ def main(): config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True + config.algo.num_from_policy = 64 config.num_training_steps = 1_00 - config.validate_every = 20 - config.num_final_gen_steps = 10 + config.validate_every = 2000 + config.num_final_gen_steps = 0 config.num_workers = 8 config.opt.lr_decay = 20_000 config.algo.sampling_tau = 0.99 diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 6bddf98b..a71cb835 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -3,7 +3,7 @@ import pathlib import shutil import time -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, List, Optional, Protocol, Union import numpy as np import torch @@ -15,6 +15,7 @@ from rdkit import RDLogger from torch import Tensor from torch.utils.data import DataLoader, Dataset +from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.data.data_source import DataSource @@ -181,8 +182,7 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - # TODO: we're changing the default, make sure anything that is using test data is adjusted - src = DataSource(self.cfg, self.ctx, self.algo, self.task, is_algo_eval=True) + n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -247,6 +247,16 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} + def _maybe_resolve_batch_buffer(self, batch: Union[Batch, BatchDescriptor], dl: DataLoader) -> Batch: + if isinstance(batch, BatchDescriptor): + print(f"buffer size was {batch.size / 1024**2:.2f}M") + wid = batch.wid + batch = resolve_batch_buffer(batch, dl.dataset.result_buffer[wid].buffer, self.device) + dl.dataset.result_buffer[wid].lock.release() + else: + batch = batch.to(self.device) + return batch + def run(self, logger=None): """Trains the GFN for `num_training_steps` minibatches, performing validation every `validate_every` minibatches. @@ -277,18 +287,8 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() - - if isinstance(batch, BatchDescriptor): - print(f"buffer size was {batch.size / 1024**2:.2f}M") - if train_dl.dataset.do_multiple_buffers: - wid = batch.wid - batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer[wid].buffer, self.device) - train_dl.dataset.result_buffer[wid].lock.release() - else: - batch = resolve_batch_buffer(batch, train_dl.dataset.result_buffer.buffer, self.device) - train_dl.dataset.result_buffer.lock.release() - else: - batch = batch.to(self.device) + _bd = batch + batch = self._maybe_resolve_batch_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") @@ -309,6 +309,7 @@ def run(self, logger=None): if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: + batch = self._maybe_resolve_batch_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -329,6 +330,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): + batch = self._maybe_resolve_batch_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index d92bb61f..0f2c6aef 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -8,6 +8,7 @@ import torch import torch.multiprocessing as mp from torch_geometric.data import Batch +import warnings from gflownet.envs.graph_building_env import GraphActionCategorical @@ -18,13 +19,24 @@ def __init__(self, size): self.buffer = torch.empty(size, dtype=torch.uint8) self.buffer.share_memory_() self.lock = mp.Lock() - - cudart = torch.cuda.cudart() - r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) - assert r == 0 + self.do_unreg = False + + if not self.buffer.is_pinned(): + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # pin it again; doing so will raise a CUDA error + cudart = torch.cuda.cudart() + r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) + assert r == 0 + self.do_unreg = True # But then we need to unregister it later assert self.buffer.is_shared() assert self.buffer.is_pinned() + def __del__(self): + if self.do_unreg and torch.utils.data.get_worker_info() is None: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 + class BatchDescriptor: def __init__(self, names, types, shapes, size, other): @@ -82,7 +94,10 @@ def resolve_batch_buffer(descriptor, buffer, device): offset = 0 batch = Batch() batch._slice_dict = {} + # Seems legit to send just a 0-starting slice, because it should be pinned as well (and timing this vs sending + # the whole buffer, it seems to be the marginally faster option) cuda_buffer = buffer[: descriptor.size].to(device) + for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): numel = prod(shape) * dtype.itemsize if name.startswith("_slice_dict_"): @@ -300,7 +315,7 @@ def to_cpu(self, i): def run(self): timeouts = 0 - while not self.stop.is_set() or timeouts < 500: + while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: r = self.decode(q.get(True, 1e-5)) @@ -379,4 +394,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, bb_size=bb_size) From c048e77ff0c5fc0cbd11c128f9c392195ec4e091 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 8 Mar 2024 13:52:54 -0700 Subject: [PATCH 16/21] major simplification by reusing pickling mechanisms --- src/gflownet/data/data_source.py | 22 +- src/gflownet/envs/seq_building_env.py | 1 + src/gflownet/tasks/seh_frag.py | 4 +- src/gflownet/trainer.py | 24 +- src/gflownet/utils/multiprocessing_proxy.py | 266 +++++++------------- 5 files changed, 108 insertions(+), 209 deletions(-) diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index eabad5cd..009b0fd4 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -3,16 +3,15 @@ import numpy as np import torch -import torch.multiprocessing as mp from torch.utils.data import IterableDataset from torch_geometric.data import Batch from gflownet import GFNAlgorithm, GFNTask from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer -from gflownet.envs.graph_building_env import GraphBuildingEnvContext, GraphActionCategorical +from gflownet.envs.graph_building_env import GraphBuildingEnvContext from gflownet.utils.misc import get_worker_rng -from gflownet.utils.multiprocessing_proxy import SharedPinnedBuffer, put_into_batch_buffer +from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer def cycle_call(it): @@ -325,20 +324,17 @@ def iterate_indices(self, n, num_samples): def setup_mp_buffers(self): if self.cfg.num_workers > 0: - self.result_buffer_size = self.cfg.mp_buffer_size - if self.result_buffer_size: - self.result_buffer = [SharedPinnedBuffer(self.result_buffer_size) for _ in range(self.cfg.num_workers)] + self.mp_buffer_size = self.cfg.mp_buffer_size + if self.mp_buffer_size: + self.result_buffer = [SharedPinnedBuffer(self.mp_buffer_size) for _ in range(self.cfg.num_workers)] else: - self.result_buffer_size = None + self.mp_buffer_size = None def _maybe_put_in_mp_buffer(self, batch): - if self.result_buffer_size: + if self.mp_buffer_size: if not (isinstance(batch, Batch)): - warnings.warn(f"Expected a Batch object, but got {type(batch)}. " "Not using mp buffers.") + warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.") return batch - self.result_buffer[self._wid].lock.acquire() - desc = put_into_batch_buffer(batch, self.result_buffer[self._wid].buffer) - desc.wid = self._wid - return desc + return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) else: return batch diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index b8189690..c5e77bab 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -69,6 +69,7 @@ def __init__(self, seqs: List[torch.Tensor], pad: int): # Since we're feeding this batch object to graph-based algorithms, we have to use this naming, but this # is the total number of timesteps. self.num_graphs = self.lens.sum().item() + self.cond_info: torch.Tensor # May be set later def to(self, device): for name in dir(self): diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index b5179124..74d997a7 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -214,10 +214,12 @@ def main(): config.num_final_gen_steps = 0 config.num_workers = 8 config.opt.lr_decay = 20_000 + config.opt.clip_grad_type = "total_norm" config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = 32 * 1024 ** 2 + config.mp_buffer_size = 32 * 1024**2 + # config.pickle_mp_messages = True trial = SEHFragTrainer(config) trial.run() diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a71cb835..1e3352e6 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -23,8 +23,8 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger, set_main_process_device +from gflownet.utils.multiprocessing_proxy import BufferUnpickler, mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook -from gflownet.utils.multiprocessing_proxy import mp_object_wrapper, resolve_batch_buffer, BatchDescriptor from .config import Config @@ -132,7 +132,7 @@ def _wrap_for_mp(self, obj): self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, - bb_size=self.cfg.mp_buffer_size, + sb_size=self.cfg.mp_buffer_size, ) self.to_terminate.append(wrapper.terminate) return wrapper.placeholder @@ -182,7 +182,7 @@ def build_training_data_loader(self) -> DataLoader: def build_validation_data_loader(self) -> DataLoader: model = self._wrap_for_mp(self.model) - + n_drawn = self.cfg.algo.valid_num_from_policy n_from_dataset = self.cfg.algo.valid_num_from_dataset @@ -247,13 +247,11 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} - def _maybe_resolve_batch_buffer(self, batch: Union[Batch, BatchDescriptor], dl: DataLoader) -> Batch: - if isinstance(batch, BatchDescriptor): - print(f"buffer size was {batch.size / 1024**2:.2f}M") - wid = batch.wid - batch = resolve_batch_buffer(batch, dl.dataset.result_buffer[wid].buffer, self.device) - dl.dataset.result_buffer[wid].lock.release() - else: + def _maybe_resolve_shared_buffer(self, batch: Union[Batch, tuple, list], dl: DataLoader) -> Batch: + if dl.dataset.mp_buffer_size > 0 and isinstance(batch, (tuple, list)): + batch, wid = batch + batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() + elif isinstance(batch, Batch): batch = batch.to(self.device) return batch @@ -288,7 +286,7 @@ def run(self, logger=None): gc.collect() torch.cuda.empty_cache() _bd = batch - batch = self._maybe_resolve_batch_buffer(batch, train_dl) + batch = self._maybe_resolve_shared_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") @@ -309,7 +307,7 @@ def run(self, logger=None): if valid_freq > 0 and it % valid_freq == 0: for batch in valid_dl: - batch = self._maybe_resolve_batch_buffer(batch, valid_dl) + batch = self._maybe_resolve_shared_buffer(batch, valid_dl) info = self.evaluate_batch(batch.to(self.device), epoch_idx, batch_idx) self.log(info, it, "valid") logger.info(f"validation - iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -330,7 +328,7 @@ def run(self, logger=None): range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): - batch = self._maybe_resolve_batch_buffer(batch, final_dl) + batch = self._maybe_resolve_shared_buffer(batch, final_dl) if hasattr(batch, "extra_info"): for k, v in batch.extra_info.items(): if k not in final_info: diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 0f2c6aef..86a3d971 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -1,16 +1,12 @@ +import io import pickle import queue import threading import traceback -from itertools import chain +from pickle import Pickler, Unpickler, UnpicklingError -import numpy as np import torch import torch.multiprocessing as mp -from torch_geometric.data import Batch -import warnings - -from gflownet.envs.graph_building_env import GraphActionCategorical class SharedPinnedBuffer: @@ -22,7 +18,7 @@ def __init__(self, size): self.do_unreg = False if not self.buffer.is_pinned(): - # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to + # Sometimes torch will create an already pinned (page aligned) buffer, so we don't need to # pin it again; doing so will raise a CUDA error cudart = torch.cuda.cudart() r = cudart.cudaHostRegister(self.buffer.data_ptr(), self.buffer.numel() * self.buffer.element_size(), 0) @@ -38,159 +34,82 @@ def __del__(self): assert r == 0 -class BatchDescriptor: - def __init__(self, names, types, shapes, size, other): - self.names = names - self.types = types - self.shapes = shapes - self.size = size - self.other = other +class _BufferPicklerSentinel: + pass -class ResultDescriptor: - def __init__(self, names, types, shapes, size, gac_attrs): - self.names = names - self.types = types - self.shapes = shapes - self.size = size - self.gac_attrs = gac_attrs +class BufferPickler(Pickler): + def __init__(self, buf: SharedPinnedBuffer): + self._f = io.BytesIO() + super().__init__(self._f) + self.buf = buf + # The lock will be released by the consumer of this buffer once the memory has been transferred to the device + self.buf.lock.acquire() + self.buf_offset = 0 + + def persistent_id(self, v): + if not isinstance(v, torch.Tensor): + return None + numel = v.numel() * v.element_size() + start = self.buf_offset + if numel > 0: + self.buf.buffer[start : start + numel] = v.view(-1).view(torch.uint8) + self.buf_offset += numel + self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes + return (_BufferPicklerSentinel, (start, tuple(v.shape), v.dtype)) + + def dumps(self, obj): + self.dump(obj) + return (self._f.getvalue(), self.buf_offset) + + +class BufferUnpickler(Unpickler): + def __init__(self, buf: SharedPinnedBuffer, data, device): + self._f, total_size = io.BytesIO(data[0]), data[1] + super().__init__(self._f) + self.buf = buf + self.target_buf = buf.buffer[:total_size].to(device) + + def load_tensor(self, offset, shape, dtype): + numel = prod(shape) * dtype.itemsize + tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) + return tensor + def persistent_load(self, pid): + if isinstance(pid, tuple): + sentinel, (offset, shape, dtype) = pid + if sentinel is _BufferPicklerSentinel: + return self.load_tensor(offset, shape, dtype) + return UnpicklingError("Invalid persistent id") -def prod(l): + def load(self): + r = super().load() + # We're done with this buffer, release it for the next consumer + self.buf.lock.release() + return r + + +def prod(ns): p = 1 - for i in l: + for i in ns: p *= i return p -def put_into_batch_buffer(batch, buffer): - names = [] - types = [] - shapes = [] - offset = 0 - others = {} - for k, v in chain(batch._store.items(), (("_slice_dict_" + k, v) for k, v in batch._slice_dict.items())): - if not isinstance(v, torch.Tensor): - try: - v = torch.as_tensor(v) - except Exception as e: - others[k] = v - continue - names.append(k) - types.append(v.dtype) - shapes.append(tuple(v.shape)) - numel = v.numel() * v.element_size() - buffer[offset : offset + numel] = v.view(-1).view(torch.uint8) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - if offset > buffer.shape[0]: - raise ValueError( - f"Offset {offset} exceeds buffer size {buffer.shape[0]}. Try increasing `cfg.mp_buffer_size`." - ) - return BatchDescriptor(names, types, shapes, offset, others) - - -def resolve_batch_buffer(descriptor, buffer, device): - offset = 0 - batch = Batch() - batch._slice_dict = {} - # Seems legit to send just a 0-starting slice, because it should be pinned as well (and timing this vs sending - # the whole buffer, it seems to be the marginally faster option) - cuda_buffer = buffer[: descriptor.size].to(device) - - for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): - numel = prod(shape) * dtype.itemsize - if name.startswith("_slice_dict_"): - batch._slice_dict[name[12:]] = cuda_buffer[offset : offset + numel].view(dtype).view(shape) - else: - setattr(batch, name, cuda_buffer[offset : offset + numel].view(dtype).view(shape)) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - - for k, v in descriptor.other.items(): - setattr(batch, k, v) - return batch - - -def put_into_result_buffer(result, buffer): - gac_names = ["logits", "batch", "slice", "masks"] - gac, tensor = result - buffer[: tensor.numel() * tensor.element_size()] = tensor.view(-1).view(torch.uint8) - offset = tensor.numel() * tensor.element_size() - offset += (8 - offset % 8) % 8 # align to 8 bytes - names = ["@per_graph_out"] - types = [tensor.dtype] - shapes = [tensor.shape] - for name in gac_names: - tensors = getattr(gac, name) - for i, x in enumerate(tensors): - numel = x.numel() * x.element_size() - if numel > 0: - # We need this for a funny reason - # torch.zeros(0)[::2] has a stride of (2,), and is contiguous according to torch - # so, flattening it and then reshaping it will not change the stride, which will - # make view(uint8) complain that the strides are not compatible. - # The batch[::2] happens when creating the categorical and deduplicate_edge_index is True - buffer[offset : offset + numel] = x.flatten().view(torch.uint8) - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - if offset > buffer.shape[0]: - raise ValueError(f"Offset {offset} exceeds buffer size {buffer.shape[0]}") - names.append(f"{name}@{i}") - types.append(x.dtype) - shapes.append(tuple(x.shape)) - return ResultDescriptor(names, types, shapes, offset, (gac.num_graphs, gac.keys, gac.types)) - - -def resolve_result_buffer(descriptor, buffer, device): - # TODO: models can return multiple GraphActionCategoricals, but we only support one for now - # Would be nice to have something generic (and recursive?) - offset = 0 - tensor = buffer[: descriptor.size].to(device) - if tensor.device == device: # CPU to CPU - # I think we need this? Otherwise when we release the lock, the memory might be overwritten - tensor = tensor.clone() - # Maybe make this a static method, or just overload __new__? - gac = GraphActionCategorical.__new__(GraphActionCategorical) - gac.num_graphs, gac.keys, gac.types = descriptor.gac_attrs - gac.dev = device - gac.logprobs = None - gac._epsilon = 1e-38 - - gac_names = ["logits", "batch", "slice", "masks"] - for i in gac_names: - setattr(gac, i, [None] * len(gac.types)) - - for name, dtype, shape in zip(descriptor.names, descriptor.types, descriptor.shapes): - numel = prod(shape) * dtype.itemsize - if name == "@per_graph_out": - per_graph_out = tensor[offset : offset + numel].view(dtype).view(shape) - else: - name, index = name.split("@") - index = int(index) - if name in gac_names: - getattr(gac, name)[index] = tensor[offset : offset + numel].view(dtype).view(shape) - else: - raise ValueError(f"Unknown result descriptor name: {name}") - offset += numel - offset += (8 - offset % 8) % 8 # align to 8 bytes - return gac, per_graph_out - - class MPObjectPlaceholder: """This class can be used for example as a model or dataset placeholder in a worker process, and translates calls to the object-placeholder into queries for the main process to execute on the real object.""" - def __init__(self, in_queues, out_queues, pickle_messages=False, batch_buffer_size=None): + def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_size=None): self.qs = in_queues, out_queues self.device = torch.device("cpu") self.pickle_messages = pickle_messages self._is_init = False - self.batch_buffer_size = batch_buffer_size - if batch_buffer_size is not None: - self._batch_buffer = SharedPinnedBuffer(batch_buffer_size) - self._result_buffer = SharedPinnedBuffer(batch_buffer_size) + self.shared_buffer_size = shared_buffer_size + if shared_buffer_size is not None: + self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size) + self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size) def _check_init(self): if self._is_init: @@ -205,19 +124,20 @@ def _check_init(self): self._is_init = True def encode(self, m): + if self.shared_buffer_size: + return BufferPickler(self._buffer_to_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) return m def decode(self, m): + if self.shared_buffer_size: + m = BufferUnpickler(self._buffer_from_main, m, self.device).load() if self.pickle_messages: m = pickle.loads(m) if isinstance(m, Exception): print("Received exception from main process, reraising.") raise m - if isinstance(m, ResultDescriptor): - m = resolve_result_buffer(m, self._result_buffer.buffer, self.device) - self._result_buffer.lock.release() return m def __getattr__(self, name): @@ -230,11 +150,6 @@ def method_wrapper(*a, **kw): def __call__(self, *a, **kw): self._check_init() - if self.batch_buffer_size and len(a) and isinstance(a[0], Batch): - # The lock will be released by the consumer of this buffer once the memory has been transferred to CUDA - self._batch_buffer.lock.acquire() - batch_descriptor = put_into_batch_buffer(a[0], self._batch_buffer.buffer) - a = (batch_descriptor,) + a[1:] self.in_queue.put(self.encode(("__call__", a, kw))) return self.decode(self.out_queue.get()) @@ -257,7 +172,7 @@ class MPObjectProxy: Always passes CPU tensors between processes. """ - def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, bb_size=None): + def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy. Parameters @@ -273,13 +188,14 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo If True, pickle messages sent between processes. This reduces load on shared memory, but increases load on CPU. It is recommended to activate this flag if encountering "Too many open files"-type errors. - bb_size: Optional[int] - batch buffer size + sb_size: Optional[int] + shared buffer size """ self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, bb_size) + self.use_shared_buffer = sb_size is not None + self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): self.device = next(obj.parameters()).device @@ -291,20 +207,16 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.thread.start() def encode(self, m): + if self.use_shared_buffer: + return BufferPickler(self.placeholder._buffer_from_main).dumps(m) if self.pickle_messages: return pickle.dumps(m) - if ( - self.placeholder.batch_buffer_size - and isinstance(m, (list, tuple)) - and len(m) == 2 - and isinstance(m[0], GraphActionCategorical) - and isinstance(m[1], torch.Tensor) - ): - self.placeholder._result_buffer.lock.acquire() - return put_into_result_buffer(m, self.placeholder._result_buffer.buffer) return m def decode(self, m): + if self.use_shared_buffer: + return BufferUnpickler(self.placeholder._buffer_to_main, m, self.device).load() + if self.pickle_messages: return pickle.loads(m) return m @@ -326,12 +238,6 @@ def run(self): break timeouts = 0 attr, args, kwargs = r - if self.placeholder.batch_buffer_size and len(args) and isinstance(args[0], BatchDescriptor): - batch = resolve_batch_buffer(args[0], self.placeholder._batch_buffer.buffer, self.device) - args = (batch,) + args[1:] - # Should this release happen after the call to f()? Are we at risk of overwriting memory that - # is still being used by CUDA? - self.placeholder._batch_buffer.lock.release() f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] kwargs = {k: i.to(self.device) if isinstance(i, self.cuda_types) else i for k, i in kwargs.items()} @@ -359,34 +265,30 @@ def terminate(self): self.stop.set() -def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, bb_size=None): +def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False, sb_size=None): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. For example, this can be used to wrap a model such that only the main process makes cuda calls by forwarding data through the model, or a replay buffer such that the new data is pushed in from the worker processes but only the main process has to hold the full buffer in memory. - self.out_queues[qi].put(self.encode(msg)) - elif isinstance(result, dict): - msg = {k: self.to_cpu(i) for k, i in result.items()} - self.out_queues[qi].put(self.encode(msg)) - else: - msg = self.to_cpu(result) - self.out_queues[qi].put(self.encode(msg)) Parameters ---------- obj: any python object to be proxied (typically a torch.nn.Module or ReplayBuffer) - Lives in the main process to which method calls are passed + Lives in the main process to which method calls are passed num_workers: int Number of DataLoader workers cast_types: tuple Types that will be cast to cuda when received as arguments of method calls. torch.Tensor is cast by default. pickle_messages: bool - If True, pickle messages sent between processes. This reduces load on shared - memory, but increases load on CPU. It is recommended to activate this flag if - encountering "Too many open files"-type errors. + If True, pickle messages sent between processes. This reduces load on shared + memory, but increases load on CPU. It is recommended to activate this flag if + encountering "Too many open files"-type errors. + sb_size: Optional[int] + If not None, creates a shared buffer of this size for sending tensors between processes. + Note, this will allocate two buffers of this size (one for sending, the other for receiving). Returns ------- @@ -394,4 +296,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, bb_size=bb_size) + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages, sb_size=sb_size) From acfe07075eec5460d8c7abbb1a48985c37bc9d22 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Mon, 11 Mar 2024 10:00:34 -0600 Subject: [PATCH 17/21] memory copy + fixes and doc --- src/gflownet/config.py | 4 ++ src/gflownet/data/data_source.py | 3 +- src/gflownet/trainer.py | 9 ++-- src/gflownet/utils/multiprocessing_proxy.py | 51 ++++++++++++++++----- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index 070d68d1..4a3035f9 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -78,6 +78,10 @@ class Config: The hostname of the machine on which the experiment is run pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) + mp_buffer_size : Optional[int] + If specified, use a buffer of this size for passing tensors between processes. + Note that this is only relevant if num_workers > 0. + Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. git_hash : Optional[str] The git hash of the current commit overwrite_existing_exp : bool diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 009b0fd4..c795ef32 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -10,6 +10,7 @@ from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphBuildingEnvContext +from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import get_worker_rng from gflownet.utils.multiprocessing_proxy import BufferPickler, SharedPinnedBuffer @@ -332,7 +333,7 @@ def setup_mp_buffers(self): def _maybe_put_in_mp_buffer(self, batch): if self.mp_buffer_size: - if not (isinstance(batch, Batch)): + if not (isinstance(batch, (Batch, SeqBatch))): warnings.warn(f"Expected a Batch object, but got {type(batch)}. Not using mp buffers.") return batch return (BufferPickler(self.result_buffer[self._wid]).dumps(batch), self._wid) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 1e3352e6..bc1e2645 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -247,11 +247,13 @@ def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0 info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} - def _maybe_resolve_shared_buffer(self, batch: Union[Batch, tuple, list], dl: DataLoader) -> Batch: - if dl.dataset.mp_buffer_size > 0 and isinstance(batch, (tuple, list)): + def _maybe_resolve_shared_buffer( + self, batch: Union[Batch, SeqBatch, tuple, list], dl: DataLoader + ) -> Union[Batch, SeqBatch]: + if dl.dataset.mp_buffer_size and isinstance(batch, (tuple, list)): batch, wid = batch batch = BufferUnpickler(dl.dataset.result_buffer[wid], batch, self.device).load() - elif isinstance(batch, Batch): + elif isinstance(batch, (Batch, SeqBatch)): batch = batch.to(self.device) return batch @@ -285,7 +287,6 @@ def run(self, logger=None): if it % 1024 == 0: gc.collect() torch.cuda.empty_cache() - _bd = batch batch = self._maybe_resolve_shared_buffer(batch, train_dl) t1 = time.time() times.append(t1 - t0) diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 86a3d971..7ed734bd 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -28,10 +28,11 @@ def __init__(self, size): assert self.buffer.is_pinned() def __del__(self): - if self.do_unreg and torch.utils.data.get_worker_info() is None: - cudart = torch.cuda.cudart() - r = cudart.cudaHostUnregister(self.buffer.data_ptr()) - assert r == 0 + if torch.utils.data.get_worker_info() is None: + if self.do_unreg: + cudart = torch.cuda.cudart() + r = cudart.cudaHostUnregister(self.buffer.data_ptr()) + assert r == 0 class _BufferPicklerSentinel: @@ -43,7 +44,8 @@ def __init__(self, buf: SharedPinnedBuffer): self._f = io.BytesIO() super().__init__(self._f) self.buf = buf - # The lock will be released by the consumer of this buffer once the memory has been transferred to the device + # The lock will be released by the consumer (BufferUnpickler) of this buffer once + # the memory has been transferred to the device and copied self.buf.lock.acquire() self.buf_offset = 0 @@ -51,12 +53,30 @@ def persistent_id(self, v): if not isinstance(v, torch.Tensor): return None numel = v.numel() * v.element_size() + if self.buf_offset + numel > self.buf.size: + raise RuntimeError( + f"Tried to allocate {self.buf_offset + numel} bytes in a buffer of size {self.buf.size}. " + "Consider increasing cfg.mp_buffer_size" + ) start = self.buf_offset + shape = tuple(v.shape) + if v.ndim > 0 and v.stride(-1) != 1 or not v.is_contiguous(): + v = v.contiguous().reshape(-1) + if v.ndim > 0 and v.stride(-1) != 1: + # We're still not contiguous, this unfortunately happens occasionally, e.g.: + # x = torch.arange(10).reshape((10, 1)) + # y = x.T[::2].T + # y.stride(), y.is_contiguous(), y.contiguous().stride() + # -> (1, 2), True, (1, 2) + v = v.flatten() + 0 + # I don't know if this comes from my misunderstanding of strides or if it's a bug in torch + # but either way torch will refuse to view this tensor as a uint8 tensor, so we have to + 0 + # to force torch to materialize it into a new tensor (it may otherwise be lazy and not materialize) if numel > 0: - self.buf.buffer[start : start + numel] = v.view(-1).view(torch.uint8) + self.buf.buffer[start : start + numel] = v.flatten().view(torch.uint8) self.buf_offset += numel self.buf_offset += (8 - self.buf_offset % 8) % 8 # align to 8 bytes - return (_BufferPicklerSentinel, (start, tuple(v.shape), v.dtype)) + return (_BufferPicklerSentinel, (start, shape, v.dtype)) def dumps(self, obj): self.dump(obj) @@ -68,11 +88,19 @@ def __init__(self, buf: SharedPinnedBuffer, data, device): self._f, total_size = io.BytesIO(data[0]), data[1] super().__init__(self._f) self.buf = buf - self.target_buf = buf.buffer[:total_size].to(device) + self.target_buf = buf.buffer[:total_size].to(device) + 0 + # Why the `+ 0`? Unfortunately, we have no way to know exactly when the consumer of the object we're + # unpickling will be done using the buffer underlying the tensor, so we have to create a copy. + # If we don't and another consumer starts using the buffer, and this consumer transfers this pinned + # buffer to the GPU, the first consumer's tensors will be corrupted, because (depending on the CUDA + # memory manager) the pinned buffer will transfer to the same GPU location. + # Hopefully, especially if the target device is the GPU, the copy will be fast and/or async. + # Note that this could be fixed by using one buffer for each worker, but that would be significantly + # more memory usage. def load_tensor(self, offset, shape, dtype): numel = prod(shape) * dtype.itemsize - tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) + tensor: torch.Tensor = self.target_buf[offset : offset + numel].view(dtype).view(shape) return tensor def persistent_load(self, pid): @@ -107,7 +135,7 @@ def __init__(self, in_queues, out_queues, pickle_messages=False, shared_buffer_s self.pickle_messages = pickle_messages self._is_init = False self.shared_buffer_size = shared_buffer_size - if shared_buffer_size is not None: + if shared_buffer_size: self._buffer_to_main = SharedPinnedBuffer(shared_buffer_size) self._buffer_from_main = SharedPinnedBuffer(shared_buffer_size) @@ -194,7 +222,7 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.in_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.out_queues = [mp.Queue() for i in range(num_workers + 1)] # type: ignore self.pickle_messages = pickle_messages - self.use_shared_buffer = sb_size is not None + self.use_shared_buffer = bool(sb_size) self.placeholder = MPObjectPlaceholder(self.in_queues, self.out_queues, pickle_messages, sb_size) self.obj = obj if hasattr(obj, "parameters"): @@ -226,7 +254,6 @@ def to_cpu(self, i): def run(self): timeouts = 0 - while not self.stop.is_set() and timeouts < 5 / 1e-5: for qi, q in enumerate(self.in_queues): try: From 907ffcd513b9531351fce45d56023f602f09deba Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 8 May 2024 16:58:25 -0600 Subject: [PATCH 18/21] fix global_cfg + opt_Z when there's no Z --- src/gflownet/__init__.py | 3 +++ src/gflownet/algo/advantage_actor_critic.py | 7 +++-- src/gflownet/algo/envelope_q_learning.py | 27 +++++++++++-------- src/gflownet/algo/flow_matching.py | 5 ++-- src/gflownet/algo/multiobjective_reinforce.py | 3 ++- src/gflownet/algo/soft_q_learning.py | 7 +++-- src/gflownet/algo/trajectory_balance.py | 13 +++++---- src/gflownet/models/graph_transformer.py | 10 +++---- src/gflownet/models/seq_transformer.py | 2 +- src/gflownet/online_trainer.py | 20 +++++++++----- 10 files changed, 58 insertions(+), 39 deletions(-) diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 6cb8f979..5415ecd8 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -23,6 +23,9 @@ class GFNAlgorithm: def step(self): self.updates += 1 # This isn't used anywhere? + def set_is_eval(self, is_eval: bool): + self.is_eval = is_eval + def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index 7077a9d1..7ce05a81 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -3,6 +3,7 @@ import torch_geometric.data as gd from torch import Tensor +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device @@ -10,7 +11,7 @@ from .graph_sampling import GraphSampler -class A2C: +class A2C(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -36,6 +37,7 @@ def __init__( The experiment configuration """ + self.global_cfg = cfg # TODO: this belongs in the base class self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -149,7 +151,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values - policy, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + policy, per_state_preds = model(batch) V = per_state_preds[:, 0] G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1 G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 9bfc3345..4798600b 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -5,6 +5,7 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.config import Config from gflownet.envs.graph_building_env import ( GraphActionCategorical, @@ -39,24 +40,24 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective num_layers=num_layers, num_heads=num_heads, ) - num_final = num_emb * 2 + num_final = num_emb num_mlp_layers = 0 self.emb2add_node = mlp(num_final, num_emb, env_ctx.num_new_node_values * num_objectives, num_mlp_layers) # Edge attr logits are "sided", so we will compute both sides independently self.emb2set_edge_attr = mlp( num_emb + num_final, num_emb, env_ctx.num_edge_attr_logits // 2 * num_objectives, num_mlp_layers ) - self.emb2stop = mlp(num_emb * 3, num_emb, num_objectives, num_mlp_layers) - self.emb2reward = mlp(num_emb * 3, num_emb, 1, num_mlp_layers) + self.emb2stop = mlp(num_emb * 2, num_emb, num_objectives, num_mlp_layers) + self.emb2reward = mlp(num_emb * 2, num_emb, 1, num_mlp_layers) self.edge2emb = mlp(num_final, num_emb, num_emb, num_mlp_layers) self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) self.action_type_order = env_ctx.action_type_order self.mask_value = -10 self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): + def forward(self, g: gd.Batch, output_Qs=False): """See `GraphTransformer` for argument values""" - node_embeddings, graph_embeddings = self.transf(g, cond) + node_embeddings, graph_embeddings = self.transf(g) # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] edge_emb = self.edge2emb(node_embeddings[e_row] + node_embeddings[e_col]) @@ -86,7 +87,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): # Compute the greedy policy # See algo.envelope_q_learning.EnvelopeQLearning.compute_batch_losses for further explanations # TODO: this makes assumptions about how conditional vectors are created! Not robust to upstream changes - w = cond[:, -self.num_objectives :] + w = g.cond_info[:, -self.num_objectives :] w_dot_Q = [ (qi.reshape((qi.shape[0], qi.shape[1] // w.shape[1], w.shape[1])) * w[b][:, None, :]).sum(2) for qi, b in zip(cat.logits, cat.batch) @@ -122,8 +123,9 @@ def __init__(self, env_ctx, num_emb=64, num_layers=3, num_heads=2, num_objective self.action_type_order = env_ctx.action_type_order self.num_objectives = num_objectives - def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch, output_Qs=False): + cond = g.cond_info + node_embeddings, graph_embeddings = self.transf(g) ne_row, ne_col = g.non_edge_index # On `::2`, edges are duplicated to make graphs undirected, only take the even ones e_row, e_col = g.edge_index[:, ::2] @@ -156,7 +158,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor, output_Qs=False): return cat, r_pred -class EnvelopeQLearning: +class EnvelopeQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -182,6 +184,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.task = task @@ -314,7 +317,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values # Q(s,a,omega) - fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True) + batch.cond_info = cond_info[batch_idx] + fwd_cat, per_state_preds = model(batch, output_Qs=True) Q_omega = fwd_cat.logits # reshape to List[shape: (num in all graphs, num actions on T, num_objectives) | for all types T] Q_omega = [i.reshape((i.shape[0], i.shape[1] // num_objectives, num_objectives)) for i in Q_omega] @@ -323,7 +327,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: batchp = batch.batch_prime batchp_num_trajs = int(batchp.traj_lens.shape[0]) batchp_batch_idx = torch.arange(batchp_num_trajs, device=dev).repeat_interleave(batchp.traj_lens) - fwd_cat_prime, per_state_preds = model(batchp, batchp.cond_info[batchp_batch_idx], output_Qs=True) + batchp.cond_info = batchp.cond_info[batchp_batch_idx] + fwd_cat_prime, per_state_preds = model(batchp, output_Qs=True) Q_omega_prime = fwd_cat_prime.logits # We've repeated everything N_omega times, so we can reshape the same way as above but with # an extra N_omega first dimension diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index 33c436bf..c75c1ce4 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -46,7 +46,7 @@ def __init__( # in a number of settings the regular loss is more stable. self.fm_balanced_loss = cfg.algo.fm.balanced_loss self.fm_leaf_coef = cfg.algo.fm.leaf_coef - self.correct_idempotent: bool = self.correct_idempotent or cfg.algo.fm.correct_idempotent + self.correct_idempotent: bool = cfg.algo.fm.correct_idempotent def construct_batch(self, trajs, cond_info, log_rewards): """Construct a batch from a list of trajectories and their information @@ -149,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Query the model for Fsa. The model will output a GraphActionCategorical, but we will # simply interpret the logits as F(s, a). Conveniently the policy of a GFN is the softmax of # log F(s,a) so we don't have to change anything in the sampling routines. - cat, graph_out = model(batch, batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)]) + batch.cond_info = batch.cond_info[torch.cat([parents_traj_idx, states_traj_idx], 0)] + cat, graph_out = model(batch) # We compute \sum_{s,a : T(s,a)=s'} F(s,a), first we index all the parent's outputs by the # parent actions. To do so we reuse the log_prob mechanism, but specify that the logprobs # tensor is actually just the logits (which we chose to interpret as edge flows F(s,a). We diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index b1a636de..52314d03 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -34,7 +34,8 @@ def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, n batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) # Forward pass of the model, returns a GraphActionCategorical and the optional bootstrap predictions - fwd_cat, log_reward_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + fwd_cat, log_reward_preds = model(batch) # This is the log prob of each action in the trajectory log_prob = fwd_cat.log_prob(batch.actions) diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index a9d61aaa..99730279 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -4,13 +4,14 @@ from torch import Tensor from torch_scatter import scatter +from gflownet import GFNAlgorithm from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config from gflownet.envs.graph_building_env import GraphBuildingEnv, GraphBuildingEnvContext, generate_forward_trajectory from gflownet.utils.misc import get_worker_device -class SoftQLearning: +class SoftQLearning(GFNAlgorithm): def __init__( self, env: GraphBuildingEnv, @@ -33,6 +34,7 @@ def __init__( cfg: Config The experiment configuration """ + self.global_cfg = cfg self.ctx = ctx self.env = env self.max_len = cfg.algo.max_len @@ -147,7 +149,8 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # Forward pass of the model, returns a GraphActionCategorical and per object predictions # Here we will interpret the logits of the fwd_cat as Q values - Q, per_state_preds = model(batch, cond_info[batch_idx]) + batch.cond_info = cond_info[batch_idx] + Q, per_state_preds = model(batch) if self.do_q_prime_correction: # First we need to estimate V_soft. We will use q_a' = \pi diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index eac57cc6..8713cde3 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -109,7 +109,7 @@ def __init__( """ self.ctx = ctx self.env = env - self.global_cfg = cfg + self.global_cfg = cfg # TODO: this belongs in the base class self.cfg = cfg.algo.tb self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes @@ -147,9 +147,6 @@ def __init__( self._subtb_max_len = self.global_cfg.algo.max_len + 2 self._init_subtb(get_worker_device()) - def set_is_eval(self, is_eval: bool): - self.is_eval = is_eval - def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, @@ -402,12 +399,14 @@ def compute_batch_losses( # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, bck_cat, per_graph_out = model(batch) else: if self.model_is_autoregressive: - fwd_cat, per_graph_out = model(batch, cond_info, batched=True) + fwd_cat, per_graph_out = model(batch, batched=True) else: - fwd_cat, per_graph_out = model(batch, batched_cond_info) + batch.cond_info = batched_cond_info + fwd_cat, per_graph_out = model(batch) # Retreive the reward predictions for the full graphs, # i.e. the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index a010b22a..b84980dc 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -90,7 +90,7 @@ def __init__( ) ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): + def forward(self, g: gd.Batch): """Forward pass Parameters @@ -112,7 +112,7 @@ def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) + c = self.c2h(g.cond_info if g.cond_info is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] # Augment the edges with a new edge to the conditioning # information node. This new node is connected to every node @@ -255,10 +255,8 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap types=action_types, ) - def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): - if cond is None: - cond = g.cond_info - node_embeddings, graph_embeddings = self.transf(g, cond) + def forward(self, g: gd.Batch): + node_embeddings, graph_embeddings = self.transf(g) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): ne_row, ne_col = g.non_edge_index diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index b694fad2..ce696d2d 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -65,7 +65,7 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) return self._logZ(cond_info) - def forward(self, xs: SeqBatch, cond, batched=False): + def forward(self, xs: SeqBatch, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. Parameters diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 103acc95..13dd0e48 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -82,12 +82,17 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) + self.opt = self._opt(non_Z_params) - self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) - self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( - self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) - ) + + if Z_params: + self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) + self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( + self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) + ) + else: + self.opt_Z = None self.sampling_tau = self.cfg.algo.sampling_tau if self.sampling_tau > 0: @@ -124,10 +129,11 @@ def step(self, loss: Tensor): g1 = model_grad_norm(self.model) self.opt.step() self.opt.zero_grad() - self.opt_Z.step() - self.opt_Z.zero_grad() self.lr_sched.step() - self.lr_sched_Z.step() + if self.opt_Z is not None: + self.opt_Z.step() + self.opt_Z.zero_grad() + self.lr_sched_Z.step() if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) From 60722a777ad27b44ea5de245acfb9118ca57d541 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:17:28 -0600 Subject: [PATCH 19/21] fix entropy when masks are used --- src/gflownet/envs/graph_building_env.py | 6 ++++-- src/gflownet/models/seq_transformer.py | 4 +--- src/gflownet/trainer.py | 1 + tests/test_graph_building_env.py | 22 +++++++++++++++++++++- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index b10b228d..601d7bd5 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -887,10 +887,12 @@ def entropy(self, logprobs=None): """ if logprobs is None: logprobs = self.logsoftmax() + masks = self.action_masks if self.action_masks is not None else [None] * len(logprobs) entropy = -sum( [ - scatter(i * i.exp(), b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) - for i, b in zip(logprobs, self.batch) + scatter(im, b, dim=0, dim_size=self.num_graphs, reduce="sum").sum(1) + for i, b, m in zip(logprobs, self.batch, masks) + for im in [i.masked_fill(m == 0.0, 0) if m is not None else i] ] ) return entropy diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index ce696d2d..8ecb8919 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -83,9 +83,7 @@ def forward(self, xs: SeqBatch, batched=False): x = self.encoder(x, src_key_padding_mask=xs.mask, mask=generate_square_subsequent_mask(x.shape[0]).to(x.device)) pooled_x = x[xs.lens - 1, torch.arange(x.shape[1])] # (batch, nemb) - if cond is None: - cond = xs.cond_info - + cond = xs.cond_info if self.use_cond: cond_var = self.cond_embed(cond) # (batch, nemb) cond_var = torch.tile(cond_var, (x.shape[0], 1, 1)) if batched else cond_var diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index c537b0ce..fcd078a8 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -219,6 +219,7 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: tick = time.time() self.model.train() try: + loss = info = None loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") diff --git a/tests/test_graph_building_env.py b/tests/test_graph_building_env.py index e9184cbd..adf41120 100644 --- a/tests/test_graph_building_env.py +++ b/tests/test_graph_building_env.py @@ -123,4 +123,24 @@ def test_log_prob(): def test_entropy(): cat = make_test_cat() - cat.entropy() + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + cat.action_masks = [ + torch.tensor([[0], [1], [1.0]]), + ((torch.arange(cat.logits[1].numel()) % 2) == 0).float().reshape(cat.logits[1].shape), + torch.tensor([[1, 0, 1], [0, 1, 1.0]]), + ] + entropy = cat.entropy() + assert torch.isfinite(entropy).all() and entropy.shape == (3,) and (entropy > 0).all() + + +def test_entropy_grad(): + # Purposefully large values to test extremal behaviors + logits = torch.tensor([[100, 101, -102, 95, 10, 20, 72]]).float() + logits.requires_grad_(True) + batch = Batch.from_data_list([Data(x=torch.ones((1, 10)), y=torch.ones((2, 6)))], follow_batch=["y"]) + cat = GraphActionCategorical(batch, [logits[:, :3], logits[:, 3:].reshape(2, 2)], [None, "y"], [None, None]) + cat._epsilon = 0 + (grad_gac,) = torch.autograd.grad(cat.entropy(), logits, retain_graph=True) + assert torch.isfinite(grad_gac).all() From f859640929fa33c7542be33981631a75c387d9be Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:36:40 -0600 Subject: [PATCH 20/21] small fixes --- docs/implementation_notes.md | 5 +---- src/gflownet/algo/config.py | 1 - src/gflownet/config.py | 2 +- src/gflownet/tasks/seh_frag.py | 5 +---- 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 4b85f2e4..ba63e708 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -53,7 +53,7 @@ The data used for training GFlowNets can come from a variety of sources. `DataSo - Generating new trajectories (w.r.t a fixed dataset of conditioning goals) - Evaluating the model's likelihood on trajectories from a fixed, offline dataset -## Multiprocessing +## Multiprocessing We use the multiprocessing features of torch's `DataLoader` to parallelize data generation and featurization. This is done by setting the `num_workers` (via `cfg.num_workers`) parameter of the `DataLoader` to a value greater than 0. Because workers cannot (easily) use a CUDA handle, we have to resort to a number of tricks. @@ -66,6 +66,3 @@ On message serialization, naively sending batches of data and results (`Batch` a We implement two solutions to this problem (in order of preference): - using `SharedPinnedBuffer`s, which are shared tensors of fixed size (`cfg.mp_buffer_size`), but initialized once and pinned. This is the fastest solution, but requires that the size of the largest possible batch/return value is known in advance. This should work for any message, but has only been tested with `Batch` and `GraphActionCategorical` messages. - using `cfg.pickle_mp_messages`, which simply serializes messages with `pickle`. This prevents the creation of lots of shared memory files, but is slower than the `SharedPinnedBuffer` solution. This should work for any message that `pickle` can handle. - - - diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index c5b1ea8c..4dd9cbfe 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -196,7 +196,6 @@ class AlgoConfig(StrictDataClass): train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 sampling_tau: float = 0.0 - compute_log_n: bool = False tb: TBConfig = field(default_factory=TBConfig) moql: MOQLConfig = field(default_factory=MOQLConfig) a2c: A2CConfig = field(default_factory=A2CConfig) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index d72bd567..86b225f0 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -80,7 +80,7 @@ class Config(StrictDataClass): pickle_mp_messages : bool Whether to pickle messages sent between processes (only relevant if num_workers > 0) mp_buffer_size : Optional[int] - If specified, use a buffer of this size for passing tensors between processes. + If specified, use a buffer of this size in bytes for passing tensors between processes. Note that this is only relevant if num_workers > 0. Also note that this will allocate `num_workers + 2 * number of wrapped objects` buffers. git_hash : Optional[str] diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 84da11c0..de4b72f0 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -129,6 +129,7 @@ class SEHFragTrainer(StandardOnlineTrainer): def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False + cfg.mp_buffer_size = 32 * 1024**2 # 32 MB cfg.num_workers = 8 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 @@ -195,18 +196,14 @@ def main(): config.log_dir = f"./logs/debug_run_seh_frag_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" config.device = "cuda" if torch.cuda.is_available() else "cpu" config.overwrite_existing_exp = True - config.algo.num_from_policy = 64 config.num_training_steps = 1_00 config.validate_every = 20 config.num_final_gen_steps = 10 config.num_workers = 1 config.opt.lr_decay = 20_000 - config.opt.clip_grad_type = "total_norm" config.algo.sampling_tau = 0.99 config.cond.temperature.sample_dist = "uniform" config.cond.temperature.dist_params = [0, 64.0] - config.mp_buffer_size = 32 * 1024**2 - # config.pickle_mp_messages = True trial = SEHFragTrainer(config) trial.run() From d536233ecde108e8a3dad5bb7640f9b75c2771c4 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 9 May 2024 08:59:47 -0600 Subject: [PATCH 21/21] removing timing prints --- src/gflownet/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index fcd078a8..4d90fa3a 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -279,8 +279,6 @@ def run(self, logger=None): num_training_steps = self.cfg.num_training_steps logger.info("Starting training") start_time = time.time() - t0 = time.time() - times = [] for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): # the memory fragmentation or allocation keeps growing, how often should we clean up? # is changing the allocation strategy helpful? @@ -289,10 +287,6 @@ def run(self, logger=None): gc.collect() torch.cuda.empty_cache() batch = self._maybe_resolve_shared_buffer(batch, train_dl) - t1 = time.time() - times.append(t1 - t0) - print(f"iteration {it} : {t1 - t0:.2f} s, average: {np.mean(times):.2f} s") - t0 = t1 epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: