Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions nx/lib/nx/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,29 @@ defmodule Nx.Defn do
:ok
end

@doc """
Wraps `fun` so that its invocation goes through `shard_jit_apply/4`.

The given `fun`, `mesh` and `opts` are forwarded to `shard_jit_apply/4`
unchanged, together with the argument list received when the wrapped
function is invoked.
"""
def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do
  # NOTE(review): `wrap/2` is defined elsewhere in this module; presumably it
  # returns a function with the same arity as `fun` that collects its
  # arguments into a list before invoking the callback - confirm.
  apply_sharded = fn args -> shard_jit_apply(fun, mesh, args, opts) end
  wrap(fun, apply_sharded)
end

@doc """
Invokes `fun` with `args` through the Shard JIT path for `mesh`.

## Options

  * `:on_conflict` - what to do when this function is invoked while
    `Nx.Defn.current/0` returns a non-`nil` value. It can be one of:

      * `:raise` (default) - raises a `RuntimeError`
      * `:force` - goes through the Shard JIT path anyway
      * `:reuse` - calls `fun` directly with `args`

All remaining options are passed on to the Shard JIT implementation.

An `ArgumentError` is raised if a conflict is detected and `:on_conflict`
is none of the values above.
"""
def shard_jit_apply(fun, mesh, args, opts \\ [])
    when is_function(fun) and is_list(args) and is_list(opts) do
  {on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise)

  # NOTE(review): confirm `Nx.Defn.current/0` exists in this module; the
  # equivalent check may live in `Nx.Defn.Compiler` instead.
  if Nx.Defn.current() == nil do
    do_shard_jit_apply(fun, mesh, args, opts)
  else
    case on_conflict do
      :raise ->
        raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening"

      :force ->
        do_shard_jit_apply(fun, mesh, args, opts)

      :reuse ->
        apply(fun, args)

      other ->
        # Without this clause an unsupported value would surface as an
        # opaque "no clause matched" error with no hint about the option.
        raise ArgumentError,
              "unknown value for :on_conflict, expected one of :raise, :force or :reuse, " <>
                "got: #{inspect(other)}"
    end
  end
end

# Raises a `CompileError` carrying `description`, located at the file and
# line read from `env`.
defp compile_error!(env, description) do
  error_fields = [line: env.line, file: env.file, description: description]
  raise CompileError, error_fields
end
Expand Down
35 changes: 35 additions & 0 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,33 @@ defmodule Nx.Defn.Compiler do
"""
@callback __to_backend__(keyword) :: {module, keyword}

@doc """
Callback for compilation of a parallelizable computation.

Its main purpose is to compile a function for a given `Nx.Defn.Shard.Mesh`.

It receives, in order:

  * `key` - an opaque term used for caching
  * `mesh` - the `Nx.Defn.Shard.Mesh` to compile for
  * `[vars]` - a list of `vars`, where each `vars` is a list of containers
  * `fun` - the function which builds a defn expression from `vars`
  * `args_list` - a list of argument lists, where each argument is a
    zero-arity function returning a tensor
  * `opts` - the compiler options

It returns a list of containers.

Using `[vars]` instead of a single `vars` allows the compiler to keep one
set of abstract parameters per shard or logical device in the mesh. This is
useful when the tensors are already divided into shards.
"""
@callback __shard_jit__(
key :: term,
mesh :: Nx.Defn.Shard.Mesh.t(),
[vars],
fun :: (vars -> Nx.Container.t()),
args_list :: [[(-> Nx.Tensor.t())]],
opts :: keyword
) :: [Nx.Container.t()]
when vars: [Nx.Container.t()]

# Compilers implementing this behaviour are not required to define
# `__shard_jit__/6`.
@optional_callbacks [
__shard_jit__: 6
]

# Modules allowed in defn
@allowed_modules [Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type]

Expand Down Expand Up @@ -265,6 +292,14 @@ defmodule Nx.Defn.Compiler do
{:__block__, [], quoted}
end

# Resolves the compiler module through `prepare_options/3` and delegates to
# that module's `__shard_jit__/6` callback. `fun` is passed both to
# `prepare_options/3` and as the callback's first argument (its caching key).
#
# An `UndefinedFunctionError` raised anywhere in the body is handed to
# `raise_missing_callback/4` together with the callback name and arity.
def __shard_jit__(fun, mesh, params, args_list, opts) do
  {compiler, runtime_fun, compiler_opts} = prepare_options(fun, mesh, opts)
  compiler.__shard_jit__(fun, mesh, params, runtime_fun, args_list, compiler_opts)
rescue
  error in UndefinedFunctionError ->
    raise_missing_callback(error, :__shard_jit__, 6, __STACKTRACE__)
end

defp compile_prepare_arities(definitions) do
for {{name, arity}, %{defaults: defaults}} <- definitions,
arity <- (arity - map_size(defaults))..arity,
Expand Down
16 changes: 16 additions & 0 deletions nx/lib/nx/defn/shard/mesh.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
defmodule Nx.Defn.Shard.Mesh do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
defmodule Nx.Defn.Shard.Mesh do
defmodule Nx.Defn.Mesh do

Let's avoid too deep nesting. So either Nx.Defn.Mesh or, if we feel it will be used outside of defn, Nx.Mesh.

@moduledoc """
A mesh is a named collection of devices arranged in a logical shape.

## Fields

  * `:name` - a string identifier for the mesh in the lowered program.
    It allows sharding annotations to refer to a specific device topology
    without embedding concrete device handles directly in the intermediate
    representation.

  * `:shape` - a tuple describing the logical layout of devices, where each
    element is the size of a mesh dimension. For instance, a shape like
    `{2, 4}` represents a 2x4 logical grid of devices.
"""
defstruct [:name, :shape]
@type t :: %__MODULE__{name: String.t(), shape: tuple()}
end