@@ -1,4 +1,4 @@
-from triton.backends.compiler import BaseBackend, Language
+from triton.backends.compiler import BaseBackend, GPUTarget, Language
 from triton._C.libtriton import ir, passes, llvm, intel
 from triton.backends.intel.driver import compile_module_from_src
 from triton.backends.intel.track import track
@@ -15,6 +15,11 @@
 import subprocess
 from pathlib import Path
 
+try:  # XPUBackend allows metaclass injection
+    from .meta import XPUBackendMeta
+except ImportError:
+    XPUBackendMeta = type(BaseBackend)
+
 
 @dataclass
 class XPUOptions:
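Note: the `.meta` module is not part of this commit; the import is wrapped in `try/except` precisely so the backend works without it. A minimal sketch of what an injected `meta.py` could look like, assuming it only needs to be metaclass-compatible with `BaseBackend`:

```python
# Hypothetical intel/meta.py (not in this commit): any injected metaclass
# must derive from type(BaseBackend) so that XPUBackend's metaclass is a
# subclass of its base class's metaclass, as Python requires.
from triton.backends.compiler import BaseBackend


class XPUBackendMeta(type(BaseBackend)):

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, namespace)
        # Hook point: e.g. record every backend subclass as it is defined.
        return cls
```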
@@ -63,40 +68,41 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
 
-def min_dot_size(device_props: dict):
-    # (M, N, K)
-    # M: repeatCount. 1,2,4,8
-    # N: executionSize. 16 for PVC, 8 for ATS
-    # K: systolicDepth x opsPerChan. systolicDepth must be 8
-    repeat_count = 1
-    sdepth = 8
-    exec_size = min(device_props["sub_group_sizes"])
-
-    def get_ops_per_channel(lhs_type, rhs_type):
-        l_bitwidth = lhs_type.scalar.primitive_bitwidth
-        r_bitwidth = rhs_type.scalar.primitive_bitwidth
-        max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
-        return min(8, max_ops_per_chan)
-
-    return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
-
-
-class XPUBackend(BaseBackend):
+class XPUBackend(BaseBackend, metaclass=XPUBackendMeta):
+    arch_to_impl = {}  # Architecture id to backend implementation class mapping
+    binary_ext = "spv"
+    target_arch = "spir64"
     instrumentation = None
 
     @staticmethod
-    def supports_target(target: tuple):
+    def supports_target(target: GPUTarget):
         return target.backend == 'xpu'
 
-    def __init__(self, target: tuple) -> None:
-        super().__init__(target)
+    def __new__(cls, target: GPUTarget):
         if not isinstance(target.arch, dict):
             raise TypeError("target.arch is not a dict")
-        dirname = os.path.dirname(os.path.realpath(__file__))
-        mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
-        self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
+        if cls is not XPUBackend:
+            return super().__new__(cls)
+        arch = target.arch.get("architecture", 0)
+        if (impl := cls.arch_to_impl.get(arch, None)) is None:
+            # Try to find an arch-specific implementation in the .arch.<name> submodule.
+            if not (dev_arch := knobs.intel.device_arch):
+                dirname = os.path.dirname(os.path.realpath(__file__))
+                parser = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(),
+                                                 name="arch_utils")
+                dev_arch = parser.parse_device_arch(target.arch.get('architecture', 0))
+            mod_name = f"{__package__}.arch.{dev_arch}"
+            try:
+                impl = __import__(mod_name, fromlist=["XPUBackendImpl"]).XPUBackendImpl
+            except ImportError:
+                impl = type(f"{mod_name}.XPUBackendImpl", (cls,), {})
+            impl.device_arch = dev_arch
+            cls.arch_to_impl[arch] = impl
+        return super().__new__(impl)
+
+    def __init__(self, target: GPUTarget) -> None:
+        super().__init__(target)
         self.properties = self.parse_target(target.arch)
-        self.binary_ext = "spv"
 
     def get_target_name(self, options) -> str:
         return f"xpu:{self.device_arch}"
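For context (not part of the diff): `__new__` caches implementations per raw architecture id and discovers them by importing `{__package__}.arch.{device_arch}`. A sketch of the contract an arch-specific submodule would satisfy; the `triton.backends.intel.compiler` import path and the choice of overridden hook are assumptions for illustration:

```python
# Hypothetical intel/arch/<device_arch>.py: __new__ above only requires the
# submodule to export a class named XPUBackendImpl derived from the backend;
# if the import fails, an empty subclass is synthesized instead. Either way,
# __new__ stamps the resulting class with its device_arch.
from triton.backends.intel.compiler import XPUBackend


class XPUBackendImpl(XPUBackend):

    @classmethod
    def optimize_llvm_mod(cls, llvm_mod, options):
        # An architecture could specialize LLVM-level optimization here,
        # before or after delegating to the generic pipeline.
        super().optimize_llvm_mod(llvm_mod, options)
```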
@@ -123,18 +129,36 @@ def parse_target(self, tgt_prop) -> dict:
         return dev_prop
 
     def parse_options(self, opts) -> Any:
-        args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
+        args = {k: v for k, v in opts.items() if k in XPUOptions.__dataclass_fields__}
         args["allow_fp8e4nv"] = True
         return XPUOptions(**args)
 
     def pack_metadata(self, metadata):
         return metadata
 
+    @staticmethod
+    def min_dot_size(device_props: dict):
+        # (M, N, K)
+        # M: repeatCount. 1,2,4,8
+        # N: executionSize. 16 for PVC, 8 for ATS
+        # K: systolicDepth x opsPerChan. systolicDepth must be 8
+        repeat_count = 1
+        sdepth = 8
+        exec_size = min(device_props["sub_group_sizes"])
+
+        def get_ops_per_channel(lhs_type, rhs_type):
+            l_bitwidth = lhs_type.scalar.primitive_bitwidth
+            r_bitwidth = rhs_type.scalar.primitive_bitwidth
+            max_ops_per_chan = 32 / max(l_bitwidth, r_bitwidth)
+            return min(8, max_ops_per_chan)
+
+        return lambda lhs_type, rhs_type: (repeat_count, exec_size, sdepth * get_ops_per_channel(lhs_type, rhs_type))
+
     def get_codegen_implementation(self, options):
         from triton.language.extra.intel import convert_custom_float8
         codegen_fns = {}
         codegen_fns["convert_custom_types"] = convert_custom_float8
-        codegen_fns["min_dot_size"] = min_dot_size(self.properties)
+        codegen_fns["min_dot_size"] = self.min_dot_size(self.properties)
         return codegen_fns
 
     def get_module_map(self) -> Dict[str, ModuleType]:
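To make the `(M, N, K)` comment concrete, here is a worked example of the shapes the returned callable produces; the device properties and the stand-in type class are assumptions for illustration:

```python
# Worked example (device properties assumed: a PVC-like device whose
# minimum sub-group size is 16).
props = {"sub_group_sizes": [16, 32]}
min_dot = XPUBackend.min_dot_size(props)


class FakeType:
    """Stand-in for a Triton element type exposing a scalar bitwidth."""

    def __init__(self, bits):
        self.scalar = self
        self.primitive_bitwidth = bits


# fp16 x fp16: opsPerChan = min(8, 32/16) = 2, so K = 8 * 2 = 16.
assert min_dot(FakeType(16), FakeType(16)) == (1, 16, 16)
# int8 x int8: opsPerChan = min(8, 32/8) = 4, so K = 8 * 4 = 32.
assert min_dot(FakeType(8), FakeType(8)) == (1, 16, 32)
```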
@@ -143,8 +167,8 @@ def get_module_map(self) -> Dict[str, ModuleType]:
 
     def load_dialects(self, ctx):
         intel.load_dialects(ctx)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.load_dialects(ctx)
+        if self.instrumentation:
+            self.instrumentation.load_dialects(ctx)
 
     @staticmethod
     def validate_options(opt, properties):
@@ -158,20 +182,15 @@ def validate_options(opt, properties):
                 f"num_warps={opt.num_warps} is unsupported for the target (limit is {properties['max_num_sub_groups']})"
             )
 
-    @staticmethod
-    def annotate_module(mod, properties, opt, target_arch):
+    @classmethod
+    def annotate_module(cls, module_opts, properties, opt):
         # Annotate module with information required by subsequent transformations.
-        pm = ir.pass_manager(mod.context)
-        pm.enable_debug()
-        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
         module_opts.min_sg_size = min(properties["sub_group_sizes"])
         module_opts.support_sg_2d_block = properties["has_subgroup_2d_block_io"]
         module_opts.support_dpas = properties["has_subgroup_matrix_multiply_accumulate"]
         module_opts.support_bf16_conversion = properties["has_bfloat16_conversions"]
         module_opts.threads_per_warp = opt.warp_size
-        module_opts.target_arch = target_arch
-        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
-        pm.run(mod, 'annotate_module')
+        module_opts.target_arch = cls.target_arch
 
     @staticmethod
     def get_split_barrier_scope(opt):
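After this refactor, `annotate_module` only populates `AnnotateModuleOptions`; the pass-manager plumbing moves into `make_ttgir` below. That makes the annotation itself overridable per architecture. A sketch of such an override; the subclass and the tweaked field are assumptions for illustration:

```python
# Hypothetical arch-specific refinement of the module annotation.
class XPUBackendImpl(XPUBackend):

    @classmethod
    def annotate_module(cls, module_opts, properties, opt):
        super().annotate_module(module_opts, properties, opt)
        # e.g., an architecture without working 2D block IO could force the
        # flag off regardless of what the driver reports.
        module_opts.support_sg_2d_block = False
```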
@@ -182,9 +201,9 @@ def get_split_barrier_scope(opt):
             split_barriers_scope = intel.SplitBarrierScope.Subgroup
         return split_barriers_scope
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttir(mod, metadata, opt):
+    def make_ttir(cls, mod, metadata, opt):
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
         passes.common.add_inliner(pm)
@@ -204,21 +223,26 @@ def make_ttir(mod, metadata, opt):
         pm.run(mod, 'make_ttir')
         return mod
 
-    @staticmethod
+    @classmethod
     @track
-    def make_ttgir(mod, metadata, opt, properties):
+    def make_ttgir(cls, mod, metadata, opt, properties):
         cluster_info = intel.ClusterInfo()
         if opt.cluster_dims is not None:
             cluster_info.clusterDimX = opt.cluster_dims[0]
             cluster_info.clusterDimY = opt.cluster_dims[1]
             cluster_info.clusterDimZ = opt.cluster_dims[2]
 
         # Annotate module with information required by subsequent transformations.
-        XPUBackend.annotate_module(mod, properties, opt, "spir64")
+        pm = ir.pass_manager(mod.context)
+        pm.enable_debug()
+        module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
+        cls.annotate_module(module_opts, properties, opt)
+        intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
+        pm.run(mod, 'annotate_module')
 
         # Overwrite the warp_size option with the module annotation.
         opt.warp_size = intel.get_threads_per_warp(mod)
-        XPUBackend.validate_options(opt, properties)
+        cls.validate_options(opt, properties)
 
         pm = ir.pass_manager(mod.context)
         pm.enable_debug()
@@ -278,9 +302,15 @@ def gluon_to_ttgir(self, src, metadata, options):
         metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
         return mod
 
-    @staticmethod
+    @classmethod
+    def optimize_llvm_mod(cls, llvm_mod, options):
+        intel.set_spv_target_triple(llvm_mod)
+        with track("optimize_module") as tr:
+            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
+
+    @classmethod
     @track
-    def make_llir(src, metadata, options):
+    def make_llir(cls, src, metadata, options):
         mod = src
         # TritonGPU -> LLVM-IR (MLIR)
         pm = ir.pass_manager(mod.context)
@@ -292,8 +322,8 @@ def make_llir(src, metadata, options):
         intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
         intel.passes.ttgpuir.add_to_llvmir(pm)
         intel.passes.ttgpuir.add_gen_to_llvm(pm)
         passes.common.add_canonicalizer(pm)
@@ -307,8 +337,8 @@ def make_llir(src, metadata, options):
         if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
             passes.llvmir.add_di_scope(pm)
 
-        if XPUBackend.instrumentation:
-            XPUBackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
+        if cls.instrumentation:
+            cls.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
         pm.run(mod, 'make_llir')
 
         if knobs.compilation.dump_ir_extract_di_local_variables:
@@ -333,15 +363,12 @@ def make_llir(src, metadata, options):
         llvm.init_targets()
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
-        intel.set_spv_target_triple(llvm_mod)
         intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
 
-        with track("optimize_module") as tr:
-            intel.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, tr.callback("passes"))
-
+        cls.optimize_llvm_mod(llvm_mod, options)
         intel.post_process_llir(llvm_mod)
 
         # Get some metadata
@@ -359,9 +386,9 @@ def make_llir(src, metadata, options):
         del context
         return ret
 
-    @staticmethod
+    @classmethod
     @track
-    def make_spv(src, metadata, options, device_arch):
+    def make_spv(cls, src, metadata, options):
         spirv, name = intel.translate_to_spirv(src)
         metadata["name"] = name
         metadata.setdefault("build_flags", "")
@@ -380,8 +407,9 @@ def make_spv(src, metadata, options, device_arch):
             metadata["build_flags"] += " -cl-opt-disable"
         return spirv
 
-    @staticmethod
-    def make_zebin(src, metadata, options, device_arch):
+    @classmethod
+    @track
+    def make_zebin(cls, src, metadata, options):
         metadata["binary_ext"] = "zebin"
 
         shader_dump_opt = ""
@@ -398,8 +426,8 @@ def make_zebin(src, metadata, options, device_arch):
         fbin = fsrc.name + '.o'
 
         ocloc_cmd = [
-            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
-            metadata["build_flags"] + shader_dump_opt
+            'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', cls.device_arch,
+            '-options', metadata["build_flags"] + shader_dump_opt
         ]
 
         try:
@@ -437,9 +465,9 @@ def add_stages(self, stages, options, language):
         elif language == Language.GLUON:
             stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
         stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
-        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
+        stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options)
         if options.generate_native_code:
-            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
+            stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options)
         if knobs.runtime.add_stages_inspection_hook is not None:
             knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
 
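Taken together, `device_arch` no longer threads through `add_stages` as an argument: `__new__` returns an arch-specific class, and `make_zebin` reads `cls.device_arch` directly. A minimal usage sketch; the import path and the target values are assumptions:

```python
# Sketch of the new dispatch (architecture id assumed for illustration).
from triton.backends.compiler import GPUTarget
from triton.backends.intel.compiler import XPUBackend

backend = XPUBackend(GPUTarget("xpu", {"architecture": 0}, 32))
# type(backend) is an arch-specific XPUBackendImpl subclass, and
# backend.device_arch comes from that class, not from __init__.
print(type(backend).__name__, backend.device_arch)
```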