Skip to content

Commit 0229c93

Browse files
committed
updated sharding implementation to be inside the existing compiler and defn files. only the Mesh module has a standalone file now
1 parent 390997d commit 0229c93

File tree

4 files changed

+54
-100
lines changed

4 files changed

+54
-100
lines changed

nx/lib/nx/defn.ex

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,29 @@ defmodule Nx.Defn do
849849
:ok
850850
end
851851

852+
def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do
853+
wrap(fun, &shard_jit_apply(fun, mesh, &1, opts))
854+
end
855+
856+
def shard_jit_apply(fun, mesh, args, opts \\ [])
857+
when is_function(fun) and is_list(args) and is_list(opts) do
858+
{on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise)
859+
860+
cond do
861+
Nx.Defn.current() == nil ->
862+
do_shard_jit_apply(fun, mesh, args, opts)
863+
864+
on_conflict == :raise ->
865+
raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening"
866+
867+
on_conflict == :force ->
868+
do_shard_jit_apply(fun, mesh, args, opts)
869+
870+
on_conflict == :reuse ->
871+
apply(fun, args)
872+
end
873+
end
874+
852875
defp compile_error!(env, description) do
853876
raise CompileError, line: env.line, file: env.file, description: description
854877
end

nx/lib/nx/defn/compiler.ex

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ defmodule Nx.Defn.Compiler do
7373
"""
7474
@callback __to_backend__(keyword) :: {module, keyword}
7575

76+
@doc """
77+
Callback for compilation.
78+
79+
Its main purpose is to compile a function for a given mesh.
80+
"""
81+
@callback __shard_jit__(
82+
key :: term,
83+
mesh :: Nx.Defn.Shard.Mesh.t(),
84+
[vars],
85+
fun :: (vars -> Nx.Container.t()),
86+
args_list :: [[(-> Nx.Tensor.t())]],
87+
opts :: keyword
88+
) :: [Nx.Container.t()]
89+
when vars: [Nx.Container.t()]
90+
91+
@optional_callbacks [
92+
__shard_jit__: 6
93+
]
94+
7695
# Modules allowed in defn
7796
@allowed_modules [Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type]
7897

@@ -265,6 +284,14 @@ defmodule Nx.Defn.Compiler do
265284
{:__block__, [], quoted}
266285
end
267286

287+
def __shard_jit__(fun, mesh, params, args_list, opts) do
288+
{module, runtime_fun, opts} = prepare_options(fun, mesh, opts)
289+
module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts)
290+
rescue
291+
e in [UndefinedFunctionError] ->
292+
raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__)
293+
end
294+
268295
defp compile_prepare_arities(definitions) do
269296
for {{name, arity}, %{defaults: defaults}} <- definitions,
270297
arity <- (arity - map_size(defaults))..arity,

nx/lib/nx/defn/shard.ex

Lines changed: 0 additions & 100 deletions
This file was deleted.

nx/lib/nx/defn/shard/mesh.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
defmodule Nx.Defn.Shard.Mesh do
2+
defstruct [:name, :shape]
3+
@type t :: %__MODULE__{name: String.t(), shape: tuple()}
4+
end

0 commit comments

Comments
 (0)