Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / torch   python

Repository URL to install this package:

Version: 2.0.1+cpu 

/ _inductor / codegen / wrapper.py

import collections
import contextlib
import dataclasses
import functools
import hashlib
from itertools import count
from typing import Any, Dict, List

from torch._dynamo.utils import dynamo_timed

from .. import codecache, config, ir
from ..codecache import cpp_compile_command, get_code_path
from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
from ..virtualized import V
from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter

pexpr = PythonPrinter().doprint


def buffer_reuse_key(node: ir.Buffer):
    """
    Build the tuple under which a buffer is filed for memory reuse.

    The key is ``(device, dtype, simplified element count, size hint of the
    last element's offset)``.
    """
    dims = node.get_size()
    strides = node.get_stride()
    sizevars = V.graph.sizevars
    last_offset = sympy_dot([dim - 1 for dim in dims], strides)
    return (
        node.get_device(),
        node.get_dtype(),
        sizevars.simplify(sympy_product(dims)),
        # Detect gaps in tensor storage caused by strides
        sizevars.size_hint(last_offset),
    )


def make_buffer_reuse(old, new, del_func, declare, ending, as_strided):
    """
    Render the statement that binds ``new``'s name to ``old``'s storage.

    ``del_func`` maps the old buffer's name to the text that releases it;
    that text is appended only when the old buffer is not a graph output.
    ``declare`` is placed in front of the statement and ``ending`` behind
    it. ``as_strided`` names the function used when size or stride differ.
    """
    assert old.get_dtype() == new.get_dtype()
    old_name = old.get_name()
    # No release text is produced for a buffer that is a graph output.
    if old_name in V.graph.get_output_names():
        release = ""
    else:
        release = del_func(old_name)

    identical_layout = (
        old.get_size() == new.get_size() and old.get_stride() == new.get_stride()
    )
    if not identical_layout:
        shape_src = V.graph.sizevars.codegen_shape_tuple(new.get_size())
        stride_src = V.graph.sizevars.codegen_shape_tuple(new.get_stride())
        return (
            f"{declare}{new.get_name()} = {as_strided}({old_name}, "
            f"{shape_src}, "
            f"{stride_src}){release}{ending}"
        )

    return f"{declare}{new.get_name()} = {old_name}{release}{ending}"


def make_buffer_allocation(buffer):
    """
    Render the Python statement that allocates ``buffer`` via ``empty_strided``.

    Only the device *type* is written into the ``device=`` argument.
    """
    device_type = buffer.get_device().type
    dtype = buffer.get_dtype()
    sizevars = V.graph.sizevars
    shape_src = sizevars.codegen_shape_tuple(tuple(buffer.get_size()))
    stride_src = sizevars.codegen_shape_tuple(tuple(buffer.get_stride()))
    return (
        f"{buffer.get_name()} = empty_strided("
        f"{shape_src}, "
        f"{stride_src}, "
        f"device='{device_type}', dtype={dtype})"
    )


def make_cpp_buffer_allocation(buffer):
    """
    Render the C++ statement that allocates ``buffer`` via ``at::empty_strided``.

    The dtype is translated through ``DTYPE_TO_ATEN``; no device is emitted.
    """
    from .cpp import DTYPE_TO_ATEN

    # TODO: map layout and device here
    aten_dtype = DTYPE_TO_ATEN[buffer.get_dtype()]
    sizevars = V.graph.sizevars
    shape_src = sizevars.codegen_shape_tuple(tuple(buffer.get_size()))
    stride_src = sizevars.codegen_shape_tuple(tuple(buffer.get_stride()))
    return (
        f"auto {buffer.get_name()} = at::empty_strided("
        f"{shape_src}, "
        f"{stride_src}, "
        f"{aten_dtype}); "
    )


class MemoryPlanningState:
    """Pool of free-lines whose buffers are available for reuse, grouped by key."""

    def __init__(self):
        super().__init__()
        self.reuse_pool: Dict[
            Any, List["FreeIfNotReusedLine"]
        ] = collections.defaultdict(list)

    def __contains__(self, key):
        """Return True when at least one candidate is waiting under ``key``."""
        # .get() keeps the lookup from creating an empty defaultdict entry.
        candidates = self.reuse_pool.get(key, None)
        return True if candidates else False

    def pop(self, key) -> "FreeIfNotReusedLine":
        """Remove and return the most recently pushed candidate for ``key``."""
        bucket = self.reuse_pool[key]
        candidate = bucket.pop()
        assert not candidate.is_reused
        return candidate

    def push(self, key, item: "FreeIfNotReusedLine"):
        """File ``item`` under ``key``; it must not already be marked reused."""
        assert not item.is_reused
        bucket = self.reuse_pool[key]
        bucket.append(item)


@dataclasses.dataclass
class EnterCudaDeviceContextManagerLine:
    """Wrapper line that emits the header of a ``with`` block for one CUDA device."""

    device_idx: int

    def codegen(self, code: IndentedBuffer):
        """Write the ``with`` header to ``code``; the indent level is not touched here."""
        # Note _DeviceGuard has less overhead than device, but only accepts
        # integers
        guard_header = f"with torch.cuda._DeviceGuard({self.device_idx}):"
        code.writeline(guard_header)


class ExitCudaDeviceContextManagerLine:
    """Marker line with no data and no methods of its own."""


class MemoryPlanningLine:
    """Base class for wrapper lines that go through a plan pass and a codegen pass."""

    def plan(self, state: MemoryPlanningState) -> "MemoryPlanningLine":
        """First pass to find reuse; the base implementation returns this line as-is."""
        return self

    def codegen(self, code: IndentedBuffer):
        """Second pass to output code; the base implementation writes nothing."""
        return None


@dataclasses.dataclass
class AllocateLine(MemoryPlanningLine):
    """Planning line that allocates ``node`` unless a freed buffer can be reused."""

    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        """
        Return a ``NullLine`` for a removed buffer, a ``ReuseLine`` when the
        pool holds a buffer with the same reuse key, or this line otherwise.
        """
        buffer_name = self.node.get_name()
        if buffer_name in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        reuse_key = buffer_reuse_key(self.node)
        if reuse_key not in state:
            return self

        donor = state.pop(reuse_key)
        donor.is_reused = True
        return ReuseLine(donor.node, self.node)

    def codegen(self, code: IndentedBuffer):
        """Write the Python allocation statement for ``node`` to ``code``."""
        assert self.node.get_name() not in V.graph.removed_buffers
        allocation_stmt = make_buffer_allocation(self.node)
        code.writeline(allocation_stmt)


@dataclasses.dataclass
class CppAllocateLine(AllocateLine):
    """C++ flavour of ``AllocateLine``: plans with ``CppReuseLine`` and emits C++."""

    def plan(self, state: MemoryPlanningState):
        """
        Return a ``NullLine`` for a removed buffer, a ``CppReuseLine`` when the
        pool holds a buffer with the same reuse key, or this line otherwise.
        """
        buffer_name = self.node.get_name()
        if buffer_name in V.graph.removed_buffers:
            return NullLine()

        # try to reuse a recently freed buffer
        reuse_key = buffer_reuse_key(self.node)
        if reuse_key not in state:
            return self

        donor = state.pop(reuse_key)
        donor.is_reused = True
        return CppReuseLine(donor.node, self.node)

    def codegen(self, code: IndentedBuffer):
        """Write the C++ allocation statement for ``node`` to ``code``."""
        assert self.node.get_name() not in V.graph.removed_buffers
        allocation_stmt = make_cpp_buffer_allocation(self.node)
        code.writeline(allocation_stmt)


@dataclasses.dataclass
class FreeIfNotReusedLine(MemoryPlanningLine):
    """Planning line that deletes ``node`` unless planning handed it to another buffer."""

    node: ir.Buffer
    is_reused: bool = False

    def plan(self, state: MemoryPlanningState):
        """Offer this buffer to the pool; removed buffers become a ``NullLine``."""
        assert not self.is_reused
        if self.node.get_name() in V.graph.removed_buffers:
            return NullLine()
        reuse_key = buffer_reuse_key(self.node)
        state.push(reuse_key, self)
        return self

    def codegen(self, code: IndentedBuffer):
        """Write ``del <name>`` to ``code`` when the buffer was not reused."""
        assert self.node.get_name() not in V.graph.removed_buffers
        if self.is_reused:
            return
        code.writeline(f"del {self.node.get_name()}")


@dataclasses.dataclass
class CppFreeIfNotReusedLine(FreeIfNotReusedLine):
    """C++ flavour of ``FreeIfNotReusedLine``: releases with ``.reset();``."""

    node: ir.Buffer
    is_reused: bool = False

    def codegen(self, code: IndentedBuffer):
        """Write ``<name>.reset();`` to ``code`` when the buffer was not reused."""
        buffer_name = self.node.get_name()
        assert buffer_name not in V.graph.removed_buffers
        if self.is_reused:
            return
        code.writeline(f"{self.node.get_name()}.reset();")


@dataclasses.dataclass
class ReuseLine(MemoryPlanningLine):
    """Planning line that rebinds the storage of ``node`` under the name of ``reused_as``."""

    node: ir.Buffer
    reused_as: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        """Check that neither buffer was removed; the line itself is kept."""
        removed = V.graph.removed_buffers
        assert self.node.get_name() not in removed
        assert self.reused_as.get_name() not in V.graph.removed_buffers
        return self

    def codegen(self, code: IndentedBuffer):
        """Write the Python reuse statement, tagged with a ``# reuse`` comment."""
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers

        def release_text(name):
            return f"; del {name}"

        reuse_stmt = make_buffer_reuse(
            self.node,
            self.reused_as,
            del_func=release_text,
            declare="",
            ending="",
            as_strided="as_strided",
        )
        code.writeline(reuse_stmt + "  # reuse")


@dataclasses.dataclass
class CppReuseLine(ReuseLine):
    """C++ flavour of ``ReuseLine``: declares with ``auto`` and releases with ``.reset()``."""

    node: ir.Buffer
    reused_as: ir.Buffer

    def codegen(self, code: IndentedBuffer):
        """Write the C++ reuse statement, tagged with a ``// reuse`` comment."""
        assert self.node.get_name() not in V.graph.removed_buffers
        assert self.reused_as.get_name() not in V.graph.removed_buffers

        def release_text(name):
            return f"; {name}.reset()"

        reuse_stmt = make_buffer_reuse(
            self.node,
            self.reused_as,
            del_func=release_text,
            declare="auto ",
            ending=";",
            as_strided="at::as_strided",
        )
        code.writeline(reuse_stmt + "  // reuse")


@dataclasses.dataclass
class FreeLine(MemoryPlanningLine):
    """Planning line that always deletes ``node``; it never joins the reuse pool."""

    node: ir.Buffer

    def plan(self, state: MemoryPlanningState):
        """Return a ``NullLine`` for a removed buffer, otherwise this line."""
        was_removed = self.node.get_name() in V.graph.removed_buffers
        return NullLine() if was_removed else self

    def codegen(self, code: IndentedBuffer):
        """Write ``del <name>`` to ``code``."""
        assert self.node.get_name() not in V.graph.removed_buffers
        delete_stmt = f"del {self.node.get_name()}"
        code.writeline(delete_stmt)


class NullLine(MemoryPlanningLine):
    """Placeholder line that inherits ``plan`` and ``codegen`` unchanged from its base."""


class WrapperCodeGen(CodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    def __init__(self):
        """
        Create the empty output buffers and seed the header and the prefix.

        The header receives the fixed import preamble (plus the triton
        imports when ``has_triton()`` is true) and one placeholder line per
        graph constant. ``write_prefix()`` is called to start the prefix.
        """
        super().__init__()
        # counter consumed by next_kernel_suffix()
        self._names_iter = count()
        self.header = IndentedBuffer()
        self.prefix = IndentedBuffer()
        self.wrapper_call = IndentedBuffer()
        self.kernels = {}
        # pending wrapper lines (strings or line objects), consumed by generate()
        self.lines = []
        self.header.splice(
            f"""
                from ctypes import c_void_p, c_long
                import torch
                import math
                import random
                from torch import empty_strided, as_strided, device
                from {codecache.__name__} import AsyncCompile
                from torch._inductor.select_algorithm import extern_kernels

                aten = torch.ops.aten
                assert_size_stride = torch._C._dynamo.guards.assert_size_stride
                async_compile = AsyncCompile()

            """
        )

        if has_triton():
            self.header.splice(
                """
                import triton
                import triton.language as tl
                from torch._inductor.triton_ops.autotune import grid
                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
                """
            )

        self.write_prefix()

        for name, value in V.graph.constants.items():
            # include a hash so our code cache gives different constants different files
            hashed = hashlib.sha256(repr(value).encode("utf-8")).hexdigest()
            self.header.writeline(f"{name} = None  # {hashed}")

        # names of buffers handled by codegen_allocation / codegen_free so far
        self.allocated = set()
        self.freed = set()

        # maps from reusing buffer to reused buffer
        self.reuses = dict()

        # memoize on this instance: a repeated call with the same index
        # returns the cached name without writing the line again
        self.write_get_cuda_stream = functools.lru_cache(None)(
            self.write_get_cuda_stream
        )

        @functools.lru_cache(None)
        def add_import_once(line):
            # the cache makes a repeated identical line a no-op
            self.header.writeline(line)

        self.add_import_once = add_import_once
        # repr(meta) -> variable name, filled by add_meta_once()
        self._metas = {}

    def add_meta_once(self, meta):
        """
        Return the header variable holding ``repr(meta)``, defining it on first use.

        Variables are named ``meta0``, ``meta1``, ... in order of first use.
        """
        meta_src = repr(meta)
        known = self._metas.get(meta_src)
        if known is not None:
            return known
        var = f"meta{len(self._metas)}"
        self._metas[meta_src] = var
        self.header.writeline(f"{var} = {meta_src}")
        return var

    @cache_on_self
    def get_output_refs(self):
        """Return the ``codegen_reference()`` of every graph output, in order."""
        refs = []
        for graph_output in V.graph.graph_outputs:
            refs.append(graph_output.codegen_reference())
        return refs

    def write_prefix(self):
        """
        Start the prefix: wait for async compilation, open ``def call(args):``,
        then unpack the inputs, refresh the random seeds and emit size variables.
        """
        self.prefix.splice(
            """

            async_compile.wait(globals())
            del async_compile

            def call(args):
            """
        )
        with self.prefix.indent():
            if config.triton.debug_sync_graph:
                self.prefix.writeline("torch.cuda.synchronize()")

            input_names = list(V.graph.graph_inputs.keys())
            if input_names:
                targets = ", ".join(input_names)
                # a lone target needs a trailing comma to unpack a 1-element list
                if len(input_names) == 1:
                    targets += ","
                self.prefix.writeline(f"{targets} = args")
                self.prefix.writeline("args.clear()")

            for seed_name in V.graph.randomness_seeds:
                self.prefix.writeline(
                    f"torch.randint(2**31, size=(), dtype=torch.int64, out={seed_name})"
                )
            V.graph.sizevars.codegen(self.prefix, V.graph.graph_inputs)

    def append_precomputed_sizes_to_prefix(self):
        """Emit the precomputed sizes into the prefix, one indent level in."""
        prefix = self.prefix
        with prefix.indent():
            V.graph.sizevars.codegen_precomputed_sizes(prefix)

    def write_get_cuda_stream(self, index):
        """Queue a line fetching the CUDA stream for ``index``; return its variable name."""
        stream_var = f"stream{index}"
        self.writeline(f"{stream_var} = get_cuda_stream({index})")
        return stream_var

    def next_kernel_suffix(self):
        """Return the next value of the instance counter, formatted as a string."""
        suffix = next(self._names_iter)
        return f"{suffix}"

    def write_allocate_line(self, buffer):
        """Queue an ``AllocateLine`` for ``buffer``."""
        allocate_line = AllocateLine(buffer)
        self.writeline(allocate_line)

    def get_deferred_line(self, name, layout):
        """Return a ``DeferredLine`` that binds ``name`` to the layout's view reference."""
        alias_stmt = f"{name} = {layout.view.codegen_reference()}  # alias"
        return DeferredLine(name, alias_stmt)

    def codegen_allocation(self, buffer):
        """
        Queue whatever line is needed to make ``buffer`` exist, at most once.

        Removed or already-allocated buffers are skipped. Extern-kernel
        allocations, multi-outputs and mutation layouts are recorded as
        allocated but produce no line. Aliased layouts allocate the viewed
        data first and then queue a deferred alias line.
        """
        name = buffer.get_name()
        already_handled = name in V.graph.removed_buffers or name in self.allocated
        if already_handled:
            return
        self.allocated.add(name)

        if isinstance(buffer, (ir.ExternKernelAlloc, ir.MultiOutput)):
            return

        layout = buffer.get_layout()
        if isinstance(layout, ir.MutationLayout):
            return

        if not isinstance(layout, ir.AliasedLayout):
            self.write_allocate_line(buffer)
            return

        assert isinstance(layout.view, ir.ReinterpretView)
        if not layout.maybe_guard_aligned():
            V.graph.unaligned_buffers.add(name)
        # the viewed storage has to exist before the alias refers to it
        self.codegen_allocation(layout.view.data)
        alias_line = self.get_deferred_line(name, layout)
        self.writeline(alias_line)

    def write_del_line(self, name):
        """Queue a ``del`` statement for ``name``."""
        delete_stmt = f"del {name}"
        self.writeline(delete_stmt)

    def write_free_if_not_reused_line(self, buffer):
        """Queue a ``FreeIfNotReusedLine`` for ``buffer``."""
        free_line = FreeIfNotReusedLine(buffer)
        self.writeline(free_line)

    def codegen_free(self, buffer):
        """
        Queue the line that releases ``buffer``.

        Input buffers and buffers with aliased or multi-output layouts get a
        plain delete line; other reusable buffers get a free-if-not-reused
        line. Buffers for which ``can_reuse`` is false produce nothing.
        """
        name = buffer.get_name()

        if isinstance(buffer, ir.InputBuffer):
            # can be freed but not reused
            self.write_del_line(name)
            return

        if not self.can_reuse(buffer):
            return
        self.freed.add(name)

        plain_delete_layouts = (ir.AliasedLayout, ir.MultiOutputLayout)
        if isinstance(buffer.get_layout(), plain_delete_layouts):
            self.write_del_line(name)
        else:
            self.write_free_if_not_reused_line(buffer)

    def can_reuse(self, buffer):
        """
        Return False when ``buffer`` is removed, a graph input, a constant,
        or already freed; True otherwise.
        """
        name = buffer.get_name()
        excluded = (
            name in V.graph.removed_buffers
            or name in V.graph.graph_inputs
            or name in V.graph.constants
            or name in self.freed
        )
        return not excluded

    def did_reuse(self, buffer, reused_buffer):
        """Return True when ``buffer`` is recorded as reusing ``reused_buffer``."""
        # Check whether a given buffer was reused by a possible reuser in the wrapper codegen
        # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed
        reuser_name = buffer.get_name()
        if reuser_name not in self.reuses:
            return False
        return self.reuses[reuser_name] == reused_buffer.get_name()

    def write_reuse_line(self, input_buffer, output_buffer):
        """Queue a ``ReuseLine`` that hands ``input_buffer`` to ``output_buffer``."""
        reuse_line = ReuseLine(input_buffer, output_buffer)
        self.writeline(reuse_line)

    def codegen_inplace_reuse(self, input_buffer, output_buffer):
        """
        Make ``output_buffer`` take over ``input_buffer``'s storage.

        The input is allocated if needed and marked freed, the output is
        marked allocated, the pairing is recorded in ``self.reuses``, and a
        reuse line is queued. Both buffers must share a reuse key.
        """
        assert buffer_reuse_key(input_buffer) == buffer_reuse_key(output_buffer)
        self.codegen_allocation(input_buffer)
        donor_name = input_buffer.get_name()
        self.freed.add(donor_name)
        taker_name = output_buffer.get_name()
        self.allocated.add(taker_name)
        self.reuses[output_buffer.get_name()] = input_buffer.get_name()
        self.write_reuse_line(input_buffer, output_buffer)

    def codegen_cuda_device_guard_enter(self, device_idx):
        """Append a device-guard enter marker for ``device_idx`` to the pending lines."""
        enter_marker = EnterCudaDeviceContextManagerLine(device_idx)
        self.lines.append(enter_marker)

    def codegen_cuda_device_guard_exit(self):
        """Append a device-guard exit marker to the pending lines."""
        exit_marker = ExitCudaDeviceContextManagerLine()
        self.lines.append(exit_marker)

    def generate_return(self, output_refs):
        """Write the ``return`` statement: a tuple of the refs, or ``()`` when empty."""
        if not output_refs:
            self.wrapper_call.writeline("return ()")
            return
        joined_refs = ", ".join(output_refs)
        self.wrapper_call.writeline("return (" + joined_refs + ", )")

    def generate_end(self, result):
        """Hook called by ``generate`` with the result buffer; this version adds nothing."""
        return None

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        """
        Queue a call to ``kernel`` with an ``out=`` keyword appended to ``args``.

        ``out`` is the view's reference when ``output_view`` is truthy and
        ``codegen_reference`` otherwise. ``args`` is modified in place;
        ``cpp_kernel`` is not used here.
        """
        if output_view:
            out_target = output_view.codegen_reference()
        else:
            out_target = codegen_reference
        args.append(f"out={out_target}")
        call_args = ", ".join(args)
        self.writeline(f"{kernel}({call_args})")

    @dynamo_timed
    def generate(self):
        """
        Assemble the complete wrapper source and return it as a string.

        The result holds, in order: the header, the prefix (with precomputed
        sizes appended), the wrapper body produced from ``self.lines``,
        whatever ``generate_end`` adds, and the optional benchmark harness.
        """
        output = IndentedBuffer()
        output.splice(self.header)

        graph_output_names = V.graph.get_output_names()
        with contextlib.ExitStack() as indent_stack:
            body = self.wrapper_call
            indent_stack.enter_context(body.indent())
            if config.profiler_mark_wrapper_call:
                body.writeline("from torch.profiler import record_function")
                body.writeline("with record_function('inductor_wrapper_call'):")
                indent_stack.enter_context(body.indent())

            # Drop trailing memory-planning lines whose buffer is not a
            # graph output: these lines will be pointless
            while self.lines:
                tail = self.lines[-1]
                if not isinstance(tail, MemoryPlanningLine):
                    break
                if tail.node.name in graph_output_names:
                    break
                self.lines.pop()

            # codegen allocations in two passes
            # pass 1: every planning line may swap itself for another line
            planning_state = MemoryPlanningState()
            for position, entry in enumerate(self.lines):
                if isinstance(entry, MemoryPlanningLine):
                    self.lines[position] = entry.plan(planning_state)

            # pass 2: emit code; device-guard markers open and close an indent
            device_cm_stack = contextlib.ExitStack()
            for entry in self.lines:
                if isinstance(entry, MemoryPlanningLine):
                    entry.codegen(body)
                elif isinstance(entry, EnterCudaDeviceContextManagerLine):
                    entry.codegen(body)
                    device_cm_stack.enter_context(body.indent())
                    body.writeline(
                        f"torch.cuda.set_device({entry.device_idx}) # no-op to ensure context"
                    )
                elif isinstance(entry, ExitCudaDeviceContextManagerLine):
                    device_cm_stack.close()
                else:
                    body.writeline(entry)

            output_refs = self.get_output_refs()
            if config.triton.debug_sync_graph:
                body.writeline("torch.cuda.synchronize()")
            self.generate_return(output_refs)

        self.append_precomputed_sizes_to_prefix()
        output.splice(self.prefix)

        with output.indent():
            output.splice(self.wrapper_call)

        self.generate_end(output)

        self.add_benchmark_harness(output)

        return output.getvalue()

    def add_benchmark_harness(self, output):
        """
        Append a benchmark harness to generated code for debugging.

        Nothing is written unless ``config.benchmark_harness`` is set. The
        harness is an ``if __name__ == "__main__":`` section that builds a
        ``rand_strided`` tensor for every constant and graph input and then
        passes ``call`` to ``print_performance``.
        """
        if not config.benchmark_harness:
            return

        def emit_random_tensor(name, shape, stride, device, dtype):
            shape_src = V.graph.sizevars.codegen_benchmark_shape_tuple(shape)
            stride_src = V.graph.sizevars.codegen_benchmark_shape_tuple(stride)
            output.writeline(
                f"{name} = rand_strided("
                f"{shape_src}, "
                f"{stride_src}, "
                f"device='{device}', dtype={dtype})"
            )

        output.writelines(["", "", 'if __name__ == "__main__":'])
        with output.indent():
            output.splice(
                """
                from torch._dynamo.testing import rand_strided
                from torch._inductor.utils import print_performance
                """,
                strip=True,
            )

            for const_name, tensor in V.graph.constants.items():
                emit_random_tensor(
                    const_name,
                    tensor.size(),
                    tensor.stride(),
                    tensor.device,
                    tensor.dtype,
                )

            for input_name, input_buffer in V.graph.graph_inputs.items():
                # graph inputs carry symbolic sizes, so size hints are used
                hinted_shape = [
                    V.graph.sizevars.size_hint(dim) for dim in input_buffer.get_size()
                ]
                hinted_stride = [
                    V.graph.sizevars.size_hint(step)
                    for step in input_buffer.get_stride()
                ]
                emit_random_tensor(
                    input_name,
                    hinted_shape,
                    hinted_stride,
                    input_buffer.get_device(),
                    input_buffer.get_dtype(),
                )

            call_args = ", ".join(V.graph.graph_inputs.keys())
            output.writeline(f"print_performance(lambda: call([{call_args}]))")

    def define_kernel(self, name: str, kernel: str):
        """Add ``name = kernel`` to the header, preceded by two newlines."""
        definition = f"\n\n{name} = {kernel}"
        self.header.splice(definition)

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        """Do nothing and return None; all three arguments are ignored here."""
        return None

    def wrap_kernel_call(self, name, call_args):
        """Return ``name(arg0, arg1, ...)`` as a string, with no terminator."""
        joined_args = ", ".join(call_args)
        return "{}({})".format(name, joined_args)

    def generate_kernel_call(self, name, call_args):
        """Queue the call text that ``wrap_kernel_call`` builds for ``name``."""
        call_text = self.wrap_kernel_call(name, call_args)
        self.writeline(call_text)

    def call_kernel(self, name: str, kernel: Kernel):
        """
        Have ``kernel`` write its call into a scratch buffer, then queue every
        non-empty line of it with surrounding whitespace stripped.
        """
        scratch = IndentedBuffer()
        kernel.call_kernel(self, scratch, name)
        for raw_line in scratch.getvalue().split("\n"):
            stripped = raw_line.strip()
            if not stripped:
                continue
            self.writeline(stripped)

    def writeline(self, line):
        """Append ``line`` (a string or a line object) to the pending wrapper lines."""
        pending = self.lines
        pending.append(line)


class CppWrapperCodeGen(WrapperCodeGen):
    """
    The outer wrapper that calls the kernels.
    """

    call_func_id = count()

    def __init__(self):
        """Take the next id from the class-level counter, then run the base initializer."""
        func_id = next(CppWrapperCodeGen.call_func_id)
        # assigned before super().__init__() runs, as in the original ordering
        self._call_func_id = func_id
        super().__init__()

    @cache_on_self
    def get_output_refs(self):
        """
        Return one reference string per graph output.

        An output's callable ``cpp_wrapper_codegen_reference`` is used when it
        has one; otherwise its ``codegen_reference`` is used.
        """

        def reference_for(graph_output):
            cpp_ref = getattr(graph_output, "cpp_wrapper_codegen_reference", None)
            if callable(cpp_ref):
                return cpp_ref()
            return graph_output.codegen_reference()

        return [reference_for(graph_output) for graph_output in V.graph.graph_outputs]

    def write_prefix(self):
        """
        Start the prefix with the C++ preamble and open the C++ ``call_<id>``
        function in the wrapper body.

        The return type is ``void``, ``at::Tensor`` or
        ``std::vector<at::Tensor>`` for zero, one or several outputs. Inputs
        are unpacked from ``args`` by index, each randomness seed is declared
        and assigned from ``at::randint``, and the size variables are emitted.
        """
        self.prefix.splice(
            """
            async_compile.wait(globals())
            del async_compile
            from torch.utils.cpp_extension import load_inline
            wrapper = (
            '''
            #include <dlfcn.h>
            #include <assert.h>

            template <typename KernelFunc>
            KernelFunc load_cpp_kernel(const char* so_filename) {
                KernelFunc kernel_cpp;
                auto kernel_cpp_lib = dlopen(so_filename, RTLD_NOW);
                assert(kernel_cpp_lib != nullptr);
                *(void **) (&kernel_cpp) = dlsym(kernel_cpp_lib, "kernel");
                return kernel_cpp;
            }

            """
        )
        body = self.wrapper_call
        with body.indent():
            input_names = list(V.graph.graph_inputs.keys())
            output_refs = self.get_output_refs()
            if not output_refs:
                output_types = "void"
            elif len(output_refs) == 1:
                output_types = "at::Tensor"
            else:
                output_types = "std::vector<at::Tensor>"

            inputs_types = "std::vector<at::Tensor>"
            body.writeline(
                f"{output_types} call_{self._call_func_id}({inputs_types} args) {{"
            )
            if input_names:
                declared = ", ".join(input_names)
                body.writeline(f"at::Tensor {declared};")
                for position, input_name in enumerate(input_names):
                    body.writeline(f"{input_name} = args[{position}];")

            for seed_name in V.graph.randomness_seeds:
                body.writeline(f"at::Tensor {seed_name};")
                body.writeline(
                    f"{seed_name} = at::randint(std::pow(2, 31), {{}}, at::ScalarType::Long);"
                )
            V.graph.sizevars.codegen(body, V.graph.graph_inputs)

    def write_allocate_line(self, buffer):
        """Queue a ``CppAllocateLine`` for ``buffer``."""
        allocate_line = CppAllocateLine(buffer)
        self.writeline(allocate_line)

    def write_del_line(self, name):
        """Queue ``<name>.reset();`` as the C++ release statement."""
        reset_stmt = f"{name}.reset();"
        self.writeline(reset_stmt)

    def write_free_if_not_reused_line(self, buffer):
        """Queue a ``CppFreeIfNotReusedLine`` for ``buffer``."""
        free_line = CppFreeIfNotReusedLine(buffer)
        self.writeline(free_line)

    def write_reuse_line(self, input_buffer, output_buffer):
        """Queue a ``CppReuseLine`` that hands ``input_buffer`` to ``output_buffer``."""
        reuse_line = CppReuseLine(input_buffer, output_buffer)
        self.writeline(reuse_line)

    def get_deferred_line(self, name, layout):
        """Return a ``DeferredLine`` declaring ``name`` as the layout's view reference in C++."""
        alias_stmt = f"auto {name} = {layout.view.codegen_reference()};  // alias"
        return DeferredLine(name, alias_stmt)

    def get_kernel_path(self, code):
        """
        Return the third element of ``get_code_path``'s result for ``code``,
        looked up with the ``so`` extension and the compile command as the
        extra key.
        """
        from ..codecache import pick_vec_isa

        vec_isa = pick_vec_isa()
        compile_cmd = cpp_compile_command("i", "o", vec_isa=vec_isa)
        # \n is required to match with the CodeCache behavior
        #  For reductions, the code string gotten from code.getvalue() will use backslash '\'
        # at the end of lines for readability purpose:
        #       #pragma omp declare reduction(xxx :\
        #                       omp_out.value = xxx,\
        # While the code string loaded during the execution will escape the backslash '\':
        #       #pragma omp declare reduction(xxx :                omp_out.value = xxx,
        # Use code.getrawvalue() here to escape the backslash to
        # make sure the same code string is used during compilation and execution,
        # so that the hash value is the same.
        hashed_source = "\n" + code.getrawvalue()
        path_parts = get_code_path(hashed_source, "so", compile_cmd)
        _, _, kernel_path = path_parts
        return kernel_path

    def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
        """
        Queue a C++ ``static auto`` declaration that loads ``name`` through
        ``load_cpp_kernel`` from the path that ``get_kernel_path`` returns.
        """
        so_path = self.get_kernel_path(kernel)
        load_stmt = (
            f'static auto {name} = load_cpp_kernel<void (*)({arg_types})>("{so_path}");'
        )
        self.writeline(load_stmt)

    def wrap_kernel_call(self, name, call_args):
        """Return ``name(arg0, arg1, ...);`` as a string, semicolon included."""
        joined_args = ", ".join(call_args)
        return "{}({});".format(name, joined_args)

    def generate_return(self, output_refs):
        """
        Write the C++ ``return`` statement followed by the closing brace,
        the closing triple quote and the closing parenthesis.

        No refs gives a bare ``return;``, one ref is returned directly, and
        several are wrapped in a ``std::vector<at::Tensor>``.
        """
        if not output_refs:
            self.wrapper_call.writeline("return; }''' )")
            return

        if len(output_refs) == 1:
            self.wrapper_call.writeline("return " + output_refs[0] + "; }''' )")
            return

        joined_refs = ", ".join(output_refs)
        self.wrapper_call.writeline(
            "return std::vector<at::Tensor>({" + joined_refs + "}); }''' )"
        )

    def generate_end(self, result):
        """
        Append to ``result`` the Python code that builds the C++ wrapper with
        ``load_inline`` and binds ``call`` to the compiled ``call_<id>``.

        Compiler flags, linker flags and include paths all come from
        ``codecache`` helpers and are each passed as a single list element.
        """
        shared = codecache.get_shared()
        warning_all_flag = codecache.get_warning_all_flag()
        cpp_flags = codecache.cpp_flags()
        ipaths, lpaths, libs, macros = codecache.get_include_and_linking_paths()
        optimization_flags = codecache.optimization_flags()
        use_custom_generated_macros = codecache.use_custom_generated_macros()

        # each group is joined into one space-separated string
        extra_cflags = f"{cpp_flags} {optimization_flags} {warning_all_flag} {macros} {use_custom_generated_macros}"
        extra_ldflags = f"{shared} {lpaths} {libs}"
        extra_include_paths = f"{ipaths}"

        # get the hash of the wrapper code to name the extension
        wrapper_call_hash = codecache.code_hash(self.wrapper_call.getvalue())
        result.splice(
            f"""
            module = load_inline(
                name='inline_extension_{wrapper_call_hash}',
                cpp_sources=[wrapper],
                functions=['call_{self._call_func_id}'],
                extra_cflags=['{extra_cflags}'],
                extra_ldflags=['{extra_ldflags}'],
                extra_include_paths=['{extra_include_paths}'])
            """
        )
        # Wrap the func to support setting result._boxed_call = True
        result.splice(
            f"""
            def _wrap_func(f):
                def g(args):
                    return f(args)
                return g
            call = _wrap_func(module.call_{self._call_func_id})
            """
        )

    def generate_extern_kernel_out(
        self, output_view, codegen_reference, args, kernel, cpp_kernel
    ):
        """
        Queue a call to ``cpp_kernel`` with the output as its first argument.

        When ``output_view`` is truthy, a ``<name>_as_strided`` variable
        holding the view's reference is declared first and passed as the
        output; otherwise ``codegen_reference`` is passed. ``args`` is
        modified in place; ``kernel`` is not used here.
        """
        if not output_view:
            args.insert(0, f"{codegen_reference}")
        else:
            strided_expr = f"{output_view.codegen_reference()}"
            strided_name = f"{output_view.get_name()}_as_strided"
            self.writeline(f"auto {strided_name} = {strided_expr};")
            args.insert(0, strided_name)
        call_args = ", ".join(args)
        self.writeline(f"{cpp_kernel}({call_args});")