diff --git a/nx/lib/nx/defn.ex b/nx/lib/nx/defn.ex index e798fb31ab..a3c50bb35e 100644 --- a/nx/lib/nx/defn.ex +++ b/nx/lib/nx/defn.ex @@ -849,6 +849,29 @@ defmodule Nx.Defn do :ok end + def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do + wrap(fun, &shard_jit_apply(fun, mesh, &1, opts)) + end + + def shard_jit_apply(fun, mesh, args, opts \\ []) + when is_function(fun) and is_list(args) and is_list(opts) do + {on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise) + + cond do + Nx.Defn.current() == nil -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :raise -> + raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening" + + on_conflict == :force -> + do_shard_jit_apply(fun, mesh, args, opts) + + on_conflict == :reuse -> + apply(fun, args) + end +end + defp compile_error!(env, description) do raise CompileError, line: env.line, file: env.file, description: description end diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index c25e691470..078feabea6 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -73,6 +73,33 @@ defmodule Nx.Defn.Compiler do """ @callback __to_backend__(keyword) :: {module, keyword} + @doc """ + Callback for compilation of a parallelizable computation. + + Its main purpose is to compile a function for a given `Nx.Defn.Shard.Mesh`. + + Receives an opaque `key` used for caching, a `mesh`, a list of `vars` + in `[vars]`, the function `fun` which builds a defn expression, a list of + argument lists in `args_list`, and the compiler options. + + Using `[vars]` instead of a single `vars` allows the compiler to keep one + set of abstract parameters per shard or logical device in the mesh. This is useful + when the tensors are already divided into shards. + """ + @callback __shard_jit__( + key :: term, + mesh :: Nx.Defn.Shard.Mesh.t(), + [vars], + fun :: (vars -> Nx.Container.t()), + args_list :: [[(-> Nx.Tensor.t())]], + opts :: keyword + ) :: [Nx.Container.t()] + when vars: [Nx.Container.t()] + + @optional_callbacks [ + __shard_jit__: 6 + ] + # Modules allowed in defn @allowed_modules [Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type] @@ -265,6 +292,14 @@ defmodule Nx.Defn.Compiler do {:__block__, [], quoted} end + def __shard_jit__(fun, mesh, params, args_list, opts) do + {module, runtime_fun, opts} = prepare_options(fun, mesh, opts) + module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts) + rescue + e in [UndefinedFunctionError] -> + raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__) + end + defp compile_prepare_arities(definitions) do for {{name, arity}, %{defaults: defaults}} <- definitions, arity <- (arity - map_size(defaults))..arity, diff --git a/nx/lib/nx/defn/shard/mesh.ex b/nx/lib/nx/defn/shard/mesh.ex new file mode 100644 index 0000000000..71a97aff33 --- /dev/null +++ b/nx/lib/nx/defn/shard/mesh.ex @@ -0,0 +1,16 @@ + defmodule Nx.Defn.Mesh do + @moduledoc """ + A mesh is a named collection of devices arranged in a logical shape. + + `name` is a string identifier for the mesh in the lowered program so that + sharding annotations can refer to a specific device topology without + embedding concrete device handles directly in the intermediate + representation. + + `shape` is a tuple describing the logical layout of devices, where each + element is the size of a mesh dimension. For instance, a shape like + `{2, 4}` represents a 2x4 logical grid of devices. + """ + defstruct [:name, :shape] + @type t :: %__MODULE__{name: String.t(), shape: tuple()} +end