|  | 
|  | 1 | +# coding: utf-8 | 
|  | 2 | +#------------------------------------------------------------------------------------------# | 
|  | 3 | +# This file is part of Pyccel which is released under MIT License. See the LICENSE file or # | 
|  | 4 | +# go to https://github.com/pyccel/pyccel/blob/master/LICENSE for full license details.     # | 
|  | 5 | +#------------------------------------------------------------------------------------------# | 
|  | 6 | +""" | 
|  | 7 | +Provide tools for generating and handling CUDA code. | 
|  | 8 | +This module is designed to interface Pyccel's Abstract Syntax Tree (AST) with CUDA, | 
|  | 9 | +enabling the direct translation of high-level Pyccel expressions into CUDA code. | 
|  | 10 | +""" | 
|  | 11 | + | 
|  | 12 | +from pyccel.codegen.printing.ccode import CCodePrinter, c_library_headers | 
|  | 13 | + | 
|  | 14 | +from pyccel.ast.core        import Import, Module | 
|  | 15 | + | 
|  | 16 | +from pyccel.errors.errors   import Errors | 
|  | 17 | + | 
|  | 18 | + | 
|  | 19 | +errors = Errors() | 
|  | 20 | + | 
|  | 21 | +__all__ = ["CudaCodePrinter"] | 
|  | 22 | + | 
|  | 23 | +class CudaCodePrinter(CCodePrinter): | 
|  | 24 | +    """ | 
|  | 25 | +    Print code in CUDA format. | 
|  | 26 | +
 | 
|  | 27 | +    This printer converts Pyccel's Abstract Syntax Tree (AST) into strings of CUDA code. | 
|  | 28 | +    Navigation through this file utilizes _print_X functions, | 
|  | 29 | +    as is common with all printers. | 
|  | 30 | +
 | 
|  | 31 | +    Parameters | 
|  | 32 | +    ---------- | 
|  | 33 | +    filename : str | 
|  | 34 | +            The name of the file being pyccelised. | 
|  | 35 | +    prefix_module : str | 
|  | 36 | +            A prefix to be added to the name of the module. | 
|  | 37 | +    """ | 
|  | 38 | +    language = "cuda" | 
|  | 39 | + | 
|  | 40 | +    def __init__(self, filename, prefix_module = None): | 
|  | 41 | + | 
|  | 42 | +        errors.set_target(filename, 'file') | 
|  | 43 | + | 
|  | 44 | +        super().__init__(filename) | 
|  | 45 | + | 
|  | 46 | +    def _print_Module(self, expr): | 
|  | 47 | +        self.set_scope(expr.scope) | 
|  | 48 | +        self._current_module = expr.name | 
|  | 49 | +        body = ''.join(self._print(i) for i in expr.body) | 
|  | 50 | + | 
|  | 51 | +        global_variables = ''.join(self._print(d) for d in expr.declarations) | 
|  | 52 | + | 
|  | 53 | +        # Print imports last to be sure that all additional_imports have been collected | 
|  | 54 | +        imports = [Import(expr.name, Module(expr.name,(),())), *self._additional_imports.values()] | 
|  | 55 | +        c_headers_imports = '' | 
|  | 56 | +        local_imports = '' | 
|  | 57 | + | 
|  | 58 | +        for imp in imports: | 
|  | 59 | +            if imp.source in c_library_headers: | 
|  | 60 | +                c_headers_imports += self._print(imp) | 
|  | 61 | +            else: | 
|  | 62 | +                local_imports += self._print(imp) | 
|  | 63 | + | 
|  | 64 | +        imports = f'{c_headers_imports}\ | 
|  | 65 | +                    extern "C"{{\n\ | 
|  | 66 | +                    {local_imports}\ | 
|  | 67 | +                    }}' | 
|  | 68 | + | 
|  | 69 | +        code = f'{imports}\n\ | 
|  | 70 | +                 {global_variables}\n\ | 
|  | 71 | +                 {body}\n' | 
|  | 72 | + | 
|  | 73 | +        self.exit_scope() | 
|  | 74 | +        return code | 
0 commit comments