Commit b0263fb
XPUBackend refactoring to facilitate arch-specific implementations
- Allow metaclass injection: use XPUBackendMeta from the .meta submodule as the metaclass, if it exists.
- Try to find an arch-specific implementation in the .arch.<name> submodule.
- Create the list of passes in separate methods so that subclasses can modify it.
1 parent 0794e64 commit b0263fb
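
To illustrate the new extension point, here is a sketch of what an arch-specific implementation could look like; the module path and the override below are hypothetical, since the commit only requires that a `.arch.<name>` submodule expose a class named `XPUBackendImpl`:

    # Hypothetical third_party/intel/backend/arch/pvc.py -- not part of this
    # commit.  XPUBackend.__new__ imports <package>.arch.<device_arch> and uses
    # the XPUBackendImpl class found there.
    from triton.backends.intel.compiler import XPUBackend


    class XPUBackendImpl(XPUBackend):

        @classmethod
        def annotate_module(cls, module_opts, properties, opt):
            # Reuse the generic annotations, then apply arch-specific tweaks.
            super().annotate_module(module_opts, properties, opt)
            module_opts.support_bf16_conversion = False  # illustrative only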

File tree

1 file changed: +91 -63 lines changed

third_party/intel/backend/compiler.py

Lines changed: 91 additions & 63 deletions
@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, Language
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, intel
 from triton.backends.intel.driver import compile_module_from_src
 from triton.backends.intel.track import track
@@ -15,6 +15,11 @@
 import subprocess
 from pathlib import Path
 
+try:  # XPUBackend allows metaclasses injection
+    from .meta import XPUBackendMeta
+except ImportError:
+    XPUBackendMeta = type(BaseBackend)
+
 
 @dataclass
 class XPUOptions:
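
The fallback `type(BaseBackend)` resolves to `BaseBackend`'s own metaclass, so the class statement below is valid whether or not a `.meta` submodule is installed. A minimal sketch of such a module, assuming only that it derives from that same metaclass (the commit does not ship one):

    # Hypothetical .meta submodule -- a sketch only.
    from triton.backends.compiler import BaseBackend


    class XPUBackendMeta(type(BaseBackend)):

        def __new__(mcls, name, bases, namespace, **kwargs):
            cls = super().__new__(mcls, name, bases, namespace, **kwargs)
            # Hook point: observe or patch backend classes as they are defined.
            return cls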
@@ -63,40 +68,41 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
 
-def min_dot_size(device_props: dict):
-    # (M, N, K)
-    # M: repeatCount. 1,2,4,8
-    # N: executionSize. 16 for PVC, 8 for ATS
-    # K: systolicDepth x opsPerChan. systolicDepth must be 8
-    repeat_count = 1
-    sdepth = 8
-    exec_size = min(device_props["sub_group_sizes"])
-
-    def get_ops_per_channel(lhs_type, rhs_type):
-        l_bitwidth = lhs_type.scalar.primitive_bitwidth
-        r_bitwidth = rhs_type.scalar.primitive_bitwidth
-        max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
-        return min(8, max_ops_per_chan)
-
-    return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
-
-
-class XPUBackend(BaseBackend):
+class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
+    arch_to_impl = {}  # Architecture id to backend implementation class mapping
+    binary_ext = "spv"
+    target_arch = "spir64"
     instrumentation = None
 
     @staticmethod
-    def supports_target(target: tuple):
+    def supports_target(target: GPUTarget):
         return target.backend == 'xpu'
 
-    def __init__(self, target: tuple) -> None:
-        super().__init__(target)
+    def __new__(cls, target: GPUTarget):
         if not isinstance(target.arch, dict):
             raise TypeError("target.arch is not a dict")
-        dirname = os.path.dirname(os.path.realpath(__file__))
-        mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
-        self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
+        if cls is not XPUBackend:
+            return super().__new__(cls)
+        arch = target.arch.get("architecture", 0)
+        if (impl := cls.arch_to_impl.get(arch, None)) is None:
+            # Try to find an arch-specific implementation in the .arch.<name> submodule.
+            if not (dev_arch := knobs.intel.device_arch):
+                dirname = os.path.dirname(os.path.realpath(__file__))
+                parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
+                                                 name="arch_utils")
+                dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0))
+            mod_name = f"{__package__}.arch.{dev_arch}"
+            try:
+                impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl
+            except ImportError:
+                impl = type(f"{mod_name}.XPUBackendImpl", (cls, ), {})
+            impl.device_arch = dev_arch
+            cls.arch_to_impl[arch] = impl
+        return super().__new__(impl)
+
+    def __init__(self, target: GPUTarget) -> None:
+        super().__init__(target)
         self.properties = self.parse_target(target.arch)
-        self.binary_ext = "spv"
 
     def get_target_name(self, options) -> str:
         return f"xpu:{self.device_arch}"
@@ -123,18 +129,36 @@ def parse_target(self, tgt_prop) -> dict:
         return dev_prop
 
     def parse_options(self, opts) -> Any:
-        args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
+        args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__}
         args["allow_fp8e4nv"] = True
         return XPUOptions(**args)
 
     def pack_metadata(self, metadata):
         return metadata
 
+    @staticmethod
+    def min_dot_size(device_props: dict):
+        # (M, N, K)
+        # M: repeatCount. 1,2,4,8
+        # N: executionSize. 16 for PVC, 8 for ATS
+        # K: systolicDepth x opsPerChan. systolicDepth must be 8
+        repeat_count = 1
+        sdepth = 8
+        exec_size = min(device_props["sub_group_sizes"])
+
+        def get_ops_per_channel(lhs_type, rhs_type):
+            l_bitwidth = lhs_type.scalar.primitive_bitwidth
+            r_bitwidth = rhs_type.scalar.primitive_bitwidth
+            max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
+            return min(8, max_ops_per_chan)
+
+        return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
+
 
     def get_codegen_implementation(self, options):
         from triton.language.extra.intel import convert_custom_float8
         codegen_fns = {}
         codegen_fns["convert_custom_types"] = convert_custom_float8
-        codegen_fns["min_dot_size"] = min_dot_size(self.properties)
+        codegen_fns["min_dot_size"] = self.min_dot_size(self.properties)
         return codegen_fns
 
     def get_module_map(self) -> Dict[str, ModuleType]:
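
Because `min_dot_size` is now a staticmethod resolved through `self`, an arch subclass can report different minimum dot shapes. A sketch with placeholder values (not real hardware limits):

    # Sketch: an arch-specific override of the minimum (M, N, K) dot shape.
    from triton.backends.intel.compiler import XPUBackend


    class XPUBackendImpl(XPUBackend):

        @staticmethod
        def min_dot_size(device_props: dict):
            # Placeholder values; a real override would derive them from
            # device_props as the base class does.
            return lambda lhs_type, rhs_type: (1, 8, 16)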
@@ -143,8 +167,8 @@ def get_module_map(self) -> Dict[str, ModuleType]:
 
     def load_dialects(self, ctx):
         intel.load_dialects(ctx)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.load_dialects(ctx)
+        if self.instrumentation:
+            self.instrumentation.load_dialects(ctx)
 
     @staticmethod
     def validate_options(opt, properties):
@@ -158,20 +182,15 @@ def validate_options(opt, properties):
                 f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})"
             )
 
-    @staticmethod
-    def annotate_module(mod, properties, opt, target_arch):
+    @classmethod
+    def annotate_module(cls, module_opts, properties, opt):
         # Annotate module with information required by subsequent transformations.
-        pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
-        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
         module_opts.min_sg_size = min(properties["sub_group_sizes"])
         module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
         module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
         module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
         module_opts.threads_per_warp = opt.warp_size
-        module_opts.target_arch = target_arch
-        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
-        pm.run(mod, 'annotate_module')
+        module_opts.target_arch = cls.target_arch
 
     @staticmethod
     def get_split_barrier_scope(opt):
@@ -182,9 +201,9 @@ def get_split_barrier_scope(opt):
             split_barriers_scope = intel.SplitBarrierScope.Subgroup
         return split_barriers_scope
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(cls, mod, metadata, opt):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
@@ -204,21 +223,26 @@ def make_ttir(mod, metadata, opt):
         pm.run(mod, 'make_ttir')
         return mod
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttgir(mod, metadata, opt, properties):
+    def make_ttgir(cls, mod, metadata, opt, properties):
         cluster_info = intel.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
             cluster_info.clusterDimY = opt.cluster_dims[1]
             cluster_info.clusterDimZ = opt.cluster_dims[2]
 
         # Annotate module with information required by subsequent transformations.
-        XPUBackend.annotate_module(mod, properties, opt, "spir64")
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
+        cls.annotate_module(module_opts, properties, opt)
+        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
+        pm.run(mod, 'annotate_module')
 
         # Overwrite the warp_size option with the module annotation.
         opt.warp_size = intel.get_threads_per_warp(mod)
-        XPUBackend.validate_options(opt, properties)
+        cls.validate_options(opt, properties)
 
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
@@ -278,9 +302,15 @@ def gluon_to_ttgir(self, src, metadata, options):
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
-    @staticmethod
+    @classmethod
+    def optimize_llvm_mod(cls, llvm_mod, options):
+        intel.set_spv_target_triple(llvm_mod)
+        with track("optimize_module") as tr:
+            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
+
+    @classmethod
     @track
-    def make_llir(src, metadata, options):
+    def make_llir(cls, src, metadata, options):
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
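
Factoring `optimize_llvm_mod` out of `make_llir` gives subclasses one seam for the LLVM optimization step. A sketch of an override, assuming `llvm.OPTIMIZE_O1` is exposed alongside the `OPTIMIZE_O3` used here:

    # Sketch: a hypothetical arch lowers the optimization level.
    from triton._C.libtriton import intel, llvm
    from triton.backends.intel.compiler import XPUBackend
    from triton.backends.intel.track import track


    class XPUBackendImpl(XPUBackend):

        @classmethod
        def optimize_llvm_mod(cls, llvm_mod, options):
            intel.set_spv_target_triple(llvm_mod)
            with track("optimize_module") as tr:
                intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O1, tr.callback("passes"))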
@@ -292,8 +322,8 @@ def make_llir(src, metadata, options):
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         intel.passes.ttgpuir.add_to_llvmir(pm)
         intel.passes.ttgpuir.add_gen_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
@@ -307,8 +337,8 @@ def make_llir(src, metadata, options):
         if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
             passes.llvmir.add_di_scope(pm)
 
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
         pm.run(mod, 'make_llir')
 
         if knobs.compilation.dump_ir_extract_di_local_variables:
@@ -333,15 +363,12 @@ def make_llir(src, metadata, options):
         llvm.init_targets()
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
-        intel.set_spv_target_triple(llvm_mod)
         intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
 
-        with track("optimize_module") as tr:
-            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
-
+        cls.optimize_llvm_mod(llvm_mod, options)
         intel.post_process_llir(llvm_mod)
 
         # Get some metadata
@@ -359,9 +386,9 @@ def make_llir(src, metadata, options):
         del context
         return ret
 
-    @staticmethod
+    @classmethod
     @track
-    def make_spv(src, metadata, options, device_arch):
+    def make_spv(cls, src, metadata, options):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
         metadata.setdefault("build_flags", "")
@@ -380,8 +407,9 @@ def make_spv(src, metadata, options, device_arch):
             metadata["build_flags"] += " -cl-opt-disable"
         return spirv
 
-    @staticmethod
-    def make_zebin(src, metadata, options, device_arch):
+    @classmethod
+    @track
+    def make_zebin(cls, src, metadata, options):
         metadata["binary_ext"] = "zebin"
 
         shader_dump_opt = ""
@@ -398,8 +426,8 @@ def make_zebin(src, metadata, options, device_arch):
             fbin = fsrc.name + '.o'
 
             ocloc_cmd = [
-                'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
-                metadata["build_flags"] + shader_dump_opt
+                'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
+                '-options', metadata["build_flags"] + shader_dump_opt
             ]
 
             try:
@@ -437,9 +465,9 @@ def add_stages(self, stages, options, language):
         elif language == Language.GLUON:
             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
-        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
+        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
         if options.generate_native_code:
-            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
+            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options)
         if knobs.runtime.add_stages_inspection_hook is not None:
             knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
 
