Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 39 additions & 28 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,13 @@
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh

logger = logging.getLogger(__name__)

def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
# this is kept at the application level, when mpirun is used to run the application
def initialize_distributed_env(rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
Expand All @@ -50,9 +31,6 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["TRTLLM_PLUGINS_PATH"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)
Expand All @@ -66,16 +44,49 @@ def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=2950
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
return device_mesh, world_size, rank


def cleanup_distributed_env():
"""Clean up distributed process group to prevent resource leaks."""
if dist.is_initialized():
dist.destroy_process_group()


def check_tensor_parallel_device_number(world_size: int) -> None:
if world_size % 2 != 0:
raise ValueError(
f"TP examples require even number of GPUs, but got {world_size} gpus"
)


def get_tensor_parallel_device_mesh(
rank: int = 0, world_size: int = 1
) -> tuple[DeviceMesh, int, int]:
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank


def initialize_distributed_logger(rank: int, logger_file_name: str) -> logging.Logger:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger
21 changes: 13 additions & 8 deletions examples/distributed_inference/tensor_parallel_rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,31 @@

"""

import logging
import os
import time

import torch
import torch_tensorrt
from rotary_embedding import RotaryAttention, parallel_rotary_block
import torch.distributed as dist
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_rotary_embedding"
)
if not dist.is_initialized():
initialize_distributed_env()

import torch_tensorrt

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_rotary_embedding")

from rotary_embedding import RotaryAttention, parallel_rotary_block

"""
This example covers the rotary embedding in Llama3 model and is derived from https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
Command to run with single GPU: mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with single GPU: USE_TRTLLM_PLUGINS=1 mpirun -n 1 --allow-run-as-root python tensor_parallel_rotary_embedding.py
Command to run with 2 GPUs: USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_rotary_embedding.py
"""

BATCH = 2
Expand Down
19 changes: 14 additions & 5 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
-----
.. code-block:: bash

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
USE_TRTLLM_PLUGINS=1 mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
"""

import time
Expand All @@ -25,22 +25,31 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
get_tensor_parallel_device_mesh,
initialize_distributed_env,
initialize_distributed_logger,
)

if not dist.is_initialized():
initialize_distributed_env()
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_simple_example")


"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
Expand Down
Loading
Loading