From 4ea06bfbc75c5d1274773b5950bdf802b0856d4b Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 15 Dec 2025 09:14:34 -0300 Subject: [PATCH 1/5] Created shard.ex --- nx/lib/nx/defn/shard.ex | 59 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 nx/lib/nx/defn/shard.ex diff --git a/nx/lib/nx/defn/shard.ex b/nx/lib/nx/defn/shard.ex new file mode 100644 index 0000000000..0606c8a786 --- /dev/null +++ b/nx/lib/nx/defn/shard.ex @@ -0,0 +1,59 @@ +defmodule Nx.Defn.Shard do + + @callback __shard_jit__( + key :: term, + vars, + fun :: (vars -> Nx.Container.t()), + args_list :: [[(-> Nx.Tensor.t())]], + opts :: keyword + ) :: [Nx.Container.t()] + when vars: [Nx.Container.t()] + + def __shard_jit__(fun, params, args_list, opts) do + {compiler, runtime_fun, opts} = prepare_options(fun, opts) + compiler.__shard_jit__(fun, params, runtime_fun, args_list, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__shard_jit__, 5, __STACKTRACE__) + end + + def shard_jit(fun, opts \\ []) when is_function(fun) and is_list(opts) do + wrap(fun, &shard_jit_apply(fun, &1, opts)) + end + + + defp raise_missing_callback(exception, name, arity, stacktrace) do + case exception do + %UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} -> + raise ArgumentError, + "the expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + + _ -> + # This is not an error that should've been caught by this function, so we pass the exception along + reraise exception, stacktrace + end + end + + defp prepare_options(fun, opts) do + {compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) + {compiler, &runtime_fun(&1, fun, compiler), opts} + end + + + + + + + def jit(fun, mesh, opts \\ []) do + ... + end + + def jit_apply(fun, mesh, args, opts \\ []) do + jit(fun, mesh, opts).(args) + end +end + +defmodule Nx.Defn.Shard.Mesh do + defstruct [:name, :shape] + @type t :: %__MODULE__{name: String.t(), shape: tuple()} +end From 390997d0af407ccda69034a0427174db5feb13a6 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Mon, 15 Dec 2025 09:17:11 -0300 Subject: [PATCH 2/5] fixed some arity issues --- nx/lib/nx/defn/shard.ex | 79 +++++++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/nx/lib/nx/defn/shard.ex b/nx/lib/nx/defn/shard.ex index 0606c8a786..b3e229bb72 100644 --- a/nx/lib/nx/defn/shard.ex +++ b/nx/lib/nx/defn/shard.ex @@ -1,7 +1,11 @@ defmodule Nx.Defn.Shard do + # Exemplo original + # @callback __shard_jit__(fun, mesh, opts) + @callback __shard_jit__( key :: term, + mesh :: Nx.Defn.Shard.Mesh.t(), vars, fun :: (vars -> Nx.Container.t()), args_list :: [[(-> Nx.Tensor.t())]], @@ -9,24 +13,43 @@ defmodule Nx.Defn.Shard do ) :: [Nx.Container.t()] when vars: [Nx.Container.t()] - def __shard_jit__(fun, params, args_list, opts) do - {compiler, runtime_fun, opts} = prepare_options(fun, opts) - compiler.__shard_jit__(fun, params, runtime_fun, args_list, opts) + def __shard_jit__(fun, mesh, params, args_list, opts) do + {module, runtime_fun, opts} = prepare_options(fun, mesh, opts) + module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts) rescue e in [UndefinedFunctionError] -> - raise_missing_callback(e, :__shard_jit__, 5, __STACKTRACE__) + raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__) + end + + def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do + wrap(fun, &shard_jit_apply(fun, mesh, &1, opts)) end - def shard_jit(fun, opts \\ []) when is_function(fun) and is_list(opts) do - wrap(fun, &shard_jit_apply(fun, &1, opts)) + def shard_jit_apply(fun, mesh, args, opts \\ []) + when is_function(fun) and is_list(args) and is_list(opts) do + {on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise) + + cond do + Nx.Defn.Shard.current() == nil -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :raise -> + raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening" + + on_conflict == :force -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :reuse -> + apply(fun, args) + end end defp raise_missing_callback(exception, name, arity, stacktrace) do case exception do - %UndefinedFunctionError{module: compiler, function: ^name, arity: ^arity} -> + %UndefinedFunctionError{module: module, function: ^name, arity: ^arity} -> raise ArgumentError, - "the expected compiler callback #{name}/#{arity} is missing. Please check that the module #{inspect(compiler)} is an Nx.Defn.Compiler." + "the expected shard callback #{name}/#{arity} is missing. Please check that the module #{inspect(module)} is an Nx.Defn.Shard." _ -> # This is not an error that should've been caught by this function, so we pass the exception along @@ -34,26 +57,44 @@ defmodule Nx.Defn.Shard do end end - defp prepare_options(fun, opts) do - {compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) - {compiler, &runtime_fun(&1, fun, compiler), opts} + defp prepare_options(fun, mesh, opts) do + {module, opts} = Keyword.pop(opts, :module, Nx.Defn.Evaluator) + {module, &runtime_fun(&1, fun, mesh, module), opts} end + defp runtime_fun(args, fun, mesh, module) do + previous_backend = Process.put(Nx.Shared.backend_pdict_key(), {Nx.Defn.Expr, []}) + previous = Process.put(Nx.Defn.Shard, module) + try do + fun + |> apply(args) + |> Nx.Defn.Composite.traverse(&Nx.Defn.Expr.tensor/1) + after + if previous_backend do + Process.put(Nx.Shared.backend_pdict_key(), previous_backend) + else + Process.delete(Nx.Shared.backend_pdict_key()) + end - - - - def jit(fun, mesh, opts \\ []) do - ... + if previous do + Process.put(Nx.Defn.Shard, module) + else + Process.delete(Nx.Defn.Shard) + end + end end - def jit_apply(fun, mesh, args, opts \\ []) do - jit(fun, mesh, opts).(args) - end + defp do_shard_jit_apply(fun, mesh, args, opts) do + opts = prepare_options(opts) + {fun, params, _templates, flatten} = Nx.Defn.Compiler.to_lazy_params(fun, args) + [res] = Nx.Defn.Shard.__shard_jit__(fun, mesh, params, [flatten], opts) + res end defmodule Nx.Defn.Shard.Mesh do defstruct [:name, :shape] @type t :: %__MODULE__{name: String.t(), shape: tuple()} end + +end From 0229c93bb778267929dca6e0f5448630632e2ff1 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:36:54 -0300 Subject: [PATCH 3/5] updated sharding implementation to be inside the existing compiler and defn files. only the Mesh module has a standalone file now --- nx/lib/nx/defn.ex | 23 ++++++++ nx/lib/nx/defn/compiler.ex | 27 ++++++++++ nx/lib/nx/defn/shard.ex | 100 ----------------------------------- nx/lib/nx/defn/shard/mesh.ex | 4 ++ 4 files changed, 54 insertions(+), 100 deletions(-) delete mode 100644 nx/lib/nx/defn/shard.ex create mode 100644 nx/lib/nx/defn/shard/mesh.ex diff --git a/nx/lib/nx/defn.ex b/nx/lib/nx/defn.ex index e798fb31ab..a3c50bb35e 100644 --- a/nx/lib/nx/defn.ex +++ b/nx/lib/nx/defn.ex @@ -849,6 +849,29 @@ defmodule Nx.Defn do :ok end + def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do + wrap(fun, &shard_jit_apply(fun, mesh, &1, opts)) + end + + def shard_jit_apply(fun, mesh, args, opts \\ []) + when is_function(fun) and is_list(args) and is_list(opts) do + {on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise) + + cond do + Nx.Defn.current() == nil -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :raise -> + raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening" + + on_conflict == :force -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :reuse -> + apply(fun, args) + end +end + defp compile_error!(env, description) do raise CompileError, line: env.line, file: env.file, description: description end diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index c25e691470..b283782100 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -73,6 +73,25 @@ defmodule Nx.Defn.Compiler do """ @callback __to_backend__(keyword) :: {module, keyword} + @doc """ + Callback for compilation. + + Its main purpose is to compile a function for a given mesh. + """ + @callback __shard_jit__( + key :: term, + mesh :: Nx.Defn.Shard.Mesh.t(), + [vars], + fun :: (vars -> Nx.Container.t()), + args_list :: [[(-> Nx.Tensor.t())]], + opts :: keyword + ) :: [Nx.Container.t()] + when vars: [Nx.Container.t()] + + @optional_callbacks [ + __shard_jit__: 6 + ] + # Modules allowed in defn @allowed_modules [Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type] @@ -265,6 +284,14 @@ defmodule Nx.Defn.Compiler do {:__block__, [], quoted} end + def __shard_jit__(fun, mesh, params, args_list, opts) do + {module, runtime_fun, opts} = prepare_options(fun, mesh, opts) + module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__) + end + defp compile_prepare_arities(definitions) do for {{name, arity}, %{defaults: defaults}} <- definitions, arity <- (arity - map_size(defaults))..arity, diff --git a/nx/lib/nx/defn/shard.ex b/nx/lib/nx/defn/shard.ex deleted file mode 100644 index b3e229bb72..0000000000 --- a/nx/lib/nx/defn/shard.ex +++ /dev/null @@ -1,100 +0,0 @@ -defmodule Nx.Defn.Shard do - - # Exemplo original - # @callback __shard_jit__(fun, mesh, opts) - - @callback __shard_jit__( - key :: term, - mesh :: Nx.Defn.Shard.Mesh.t(), - vars, - fun :: (vars -> Nx.Container.t()), - args_list :: [[(-> Nx.Tensor.t())]], - opts :: keyword - ) :: [Nx.Container.t()] - when vars: [Nx.Container.t()] - - def __shard_jit__(fun, mesh, params, args_list, opts) do - {module, runtime_fun, opts} = prepare_options(fun, mesh, opts) - module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts) - rescue - e in [UndefinedFunctionError] -> - raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__) - end - - def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do - wrap(fun, &shard_jit_apply(fun, mesh, &1, opts)) - end - - def shard_jit_apply(fun, mesh, args, opts \\ []) - when is_function(fun) and is_list(args) and is_list(opts) do - {on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise) - - cond do - Nx.Defn.Shard.current() == nil -> - do_shard_jit_apply(fun, mesh, args, opts) - - on_conflict == :raise -> - raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening" - - on_conflict == :force -> - do_shard_jit_apply(fun, mesh, args, opts) - - on_conflict == :reuse -> - apply(fun, args) - end - end - - - defp raise_missing_callback(exception, name, arity, stacktrace) do - case exception do - %UndefinedFunctionError{module: module, function: ^name, arity: ^arity} -> - raise ArgumentError, - "the expected shard callback #{name}/#{arity} is missing. Please check that the module #{inspect(module)} is an Nx.Defn.Shard." - - _ -> - # This is not an error that should've been caught by this function, so we pass the exception along - reraise exception, stacktrace - end - end - - defp prepare_options(fun, mesh, opts) do - {module, opts} = Keyword.pop(opts, :module, Nx.Defn.Evaluator) - {module, &runtime_fun(&1, fun, mesh, module), opts} - end - - defp runtime_fun(args, fun, mesh, module) do - previous_backend = Process.put(Nx.Shared.backend_pdict_key(), {Nx.Defn.Expr, []}) - previous = Process.put(Nx.Defn.Shard, module) - - try do - fun - |> apply(args) - |> Nx.Defn.Composite.traverse(&Nx.Defn.Expr.tensor/1) - after - if previous_backend do - Process.put(Nx.Shared.backend_pdict_key(), previous_backend) - else - Process.delete(Nx.Shared.backend_pdict_key()) - end - - if previous do - Process.put(Nx.Defn.Shard, module) - else - Process.delete(Nx.Defn.Shard) - end - end - end - - defp do_shard_jit_apply(fun, mesh, args, opts) do - opts = prepare_options(opts) - {fun, params, _templates, flatten} = Nx.Defn.Compiler.to_lazy_params(fun, args) - [res] = Nx.Defn.Shard.__shard_jit__(fun, mesh, params, [flatten], opts) - res -end - -defmodule Nx.Defn.Shard.Mesh do - defstruct [:name, :shape] - @type t :: %__MODULE__{name: String.t(), shape: tuple()} -end - -end diff --git a/nx/lib/nx/defn/shard/mesh.ex b/nx/lib/nx/defn/shard/mesh.ex new file mode 100644 index 0000000000..2588fc4ae9 --- /dev/null +++ b/nx/lib/nx/defn/shard/mesh.ex @@ -0,0 +1,4 @@ +defmodule Nx.Defn.Shard.Mesh do + defstruct [:name, :shape] + @type t :: %__MODULE__{name: String.t(), shape: tuple()} +end From 3283d19a7088918509d07c2dd9dd14887399bb17 Mon Sep 17 00:00:00 2001 From: Chapaman <14204271+Chapaman@users.noreply.github.com> Date: Wed, 17 Dec 2025 20:04:52 -0300 Subject: [PATCH 4/5] Expanded on the documentation as suggested by @polvalente --- nx/lib/nx/defn/compiler.ex | 12 ++++++++++-- nx/lib/nx/defn/shard/mesh.ex | 14 +++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index b283782100..078feabea6 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -74,9 +74,17 @@ defmodule Nx.Defn.Compiler do @callback __to_backend__(keyword) :: {module, keyword} @doc """ - Callback for compilation. + Callback for compilation of a parallelizable computation. + + Its main purpose is to compile a function for a given `Nx.Defn.Shard.Mesh`. + + Receives an opaque `key` used for caching, a `mesh`, a list of `vars` + in `[vars]`, the function `fun` which builds a defn expression, a list of + argument lists in `args_list`, and the compiler options. - Its main purpose is to compile a function for a given mesh. + Using `[vars]` instead of a single `vars` allows the compiler to keep one + set of abstract parameters per shard or logical device in the mesh. This is useful + when the tensors are already divided into shards. """ @callback __shard_jit__( key :: term, diff --git a/nx/lib/nx/defn/shard/mesh.ex b/nx/lib/nx/defn/shard/mesh.ex index 2588fc4ae9..c84fb59bf1 100644 --- a/nx/lib/nx/defn/shard/mesh.ex +++ b/nx/lib/nx/defn/shard/mesh.ex @@ -1,4 +1,16 @@ -defmodule Nx.Defn.Shard.Mesh do + defmodule Nx.Defn.Shard.Mesh do + @moduledoc """ + A mesh is a named collection of devices arranged in a logical shape. + + `name` is a string identifier for the mesh in the lowered program so that + sharding annotations can refer to a specific device topology without + embedding concrete device handles directly in the intermediate + representation. + + `shape` is a tuple describing the logical layout of devices, where each + element is the size of a mesh dimension. For instance, a shape like + `{2, 4}` represents a 2x4 logical grid of devices. + """ defstruct [:name, :shape] @type t :: %__MODULE__{name: String.t(), shape: tuple()} end From 86fe7b4fda2985c1d7e1ef22d55524d18c6b7fbf Mon Sep 17 00:00:00 2001 From: Arthur Romeu <14204271+Chapaman@users.noreply.github.com> Date: Fri, 19 Dec 2025 18:10:14 -0300 Subject: [PATCH 5/5] Update nx/lib/nx/defn/shard/mesh.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- nx/lib/nx/defn/shard/mesh.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/lib/nx/defn/shard/mesh.ex b/nx/lib/nx/defn/shard/mesh.ex index c84fb59bf1..71a97aff33 100644 --- a/nx/lib/nx/defn/shard/mesh.ex +++ b/nx/lib/nx/defn/shard/mesh.ex @@ -1,4 +1,4 @@ - defmodule Nx.Defn.Shard.Mesh do + defmodule Nx.Defn.Mesh do @moduledoc """ A mesh is a named collection of devices arranged in a logical shape.