# torch/profiler/_pattern_matcher.py

import json
import math
import os
import re
from typing import Dict, List, Optional, Set

import torch
from torch.profiler import profile
import torch.utils.benchmark as benchmark
from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs
from torch._C._profiler import (_ProfilerEvent, _ExtraFields_TorchOp,
                                _ExtraFields_PyCCall, _ExtraFields_PyCall,
                                _EventType)


class Pattern:
    '''
    Base class for all patterns. Subclass this class and implement match()
    to define custom patterns.

    In the subclass, define the description and the skip property.
    '''
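    # A minimal sketch of a custom pattern, for illustration only (the class
    # name and the op it matches are made up, not part of this module):
    #
    #     class AddOpPattern(Pattern):
    #         def __init__(self, prof, should_benchmark=False):
    #             super().__init__(prof, should_benchmark)
    #             self.name = "Add Op Pattern"
    #             self.description = "Found an aten::add call."
    #
    #         def match(self, event):
    #             return event.name == "aten::add"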

    def __init__(self, prof: profile, should_benchmark: bool = False):
        self.prof = prof
        self.should_benchmark = should_benchmark
        self.name = "Please specify a name for pattern"
        self.description = "Please specify a description for pattern"
        self.url = ""
        assert prof.profiler is not None and prof.profiler.kineto_results is not None
        self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
        self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
        for event in self.event_tree:
            self.tid_root.setdefault(event.start_tid, []).append(event)

    @property
    def skip(self):
        return False

    def report(self, event: _ProfilerEvent):
        msg = f"{self.description}\n[Source Code Location] {source_code_location(event)}"
        return msg

    def eventTreeTraversal(self):
        '''
        Traverse the event tree and yield all events.
        Override this method in a subclass to customize the traversal.
        '''
        yield from traverse_dfs(self.event_tree)

    def summary(self, events: List[_ProfilerEvent]):
        default_summary = f"{self.name}: {len(events)} events matched."
        if self.should_benchmark:
            # Use the benchmark-based summary if the subclass implements benchmark().
            if hasattr(self, 'benchmark'):  # type: ignore[attr-defined]
                return self.benchmark_summary(events)
        return default_summary

    def benchmark_summary(self, events: List[_ProfilerEvent]):

        def format_time(time_ns: int):
            unit_lst = ["ns", "us", "ms"]
            for unit in unit_lst:
                if time_ns < 1000:
                    return f"{time_ns:.2f} {unit}"
                # Use true division so fractional parts survive the .2f format.
                time_ns /= 1000
            return f"{time_ns:.2f} s"

        assert hasattr(self, 'benchmark'), 'Please implement benchmark()'
        shapes_factor_map = self.benchmark(  # type: ignore[attr-defined]
            events)
        original_time = sum(event.duration_time_ns for event in events)
        new_time = sum(shapes_factor_map[input_shapes(event)] *
                       event.duration_time_ns for event in events)
        return (
            f"{self.name}: {len(events)} events matched. "
            f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
        )

    def match(self, event: _ProfilerEvent):
        '''
        Return True if the event matches the pattern.
        This method should be overridden in a subclass.
        '''
        raise NotImplementedError

    def matched_events(self):
        if self.skip:
            return []
        matched_events = []
        for event in self.eventTreeTraversal():
            if self.match(event):
                matched_events.append(event)
        return matched_events

    def root_of(self, event: _ProfilerEvent):
        while event.parent:
            event = event.parent
        return event

    def siblings_of(self, event: _ProfilerEvent):
        if event.parent:
            children = event.parent.children
        else:
            children = self.tid_root[event.start_tid]
        index = children.index(event)
        return children[:index], children[index + 1:]

    def next_of(self, event: _ProfilerEvent):
        _, next_events = self.siblings_of(event)
        return next_events[0] if next_events else None

    def prev_of(self, event: _ProfilerEvent):
        prev_events, _ = self.siblings_of(event)
        return prev_events[-1] if prev_events else None

    def go_up_until(self, event: _ProfilerEvent, predicate):
        if not event:
            return None
        while event.parent and not predicate(event):
            event = event.parent
        return event


# Patterns


class NamePattern(Pattern):

    def __init__(self,
                 prof: profile,
                 name: str,
                 should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.description = f"Matched Name Event: {name}"
        self.name = name

    def match(self, event: _ProfilerEvent):
        return re.search(self.name, event.name) is not None
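
# A hedged usage sketch of the patterns in this file (the model/input names
# and the profiler arguments shown are assumptions about the caller, not
# requirements of this module):
#
#     with torch.profiler.profile(with_stack=True, record_shapes=True) as prof:
#         model(inputs)
#     pattern = NamePattern(prof, "aten::mm")
#     for event in pattern.matched_events():
#         print(pattern.report(event))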


class ExtraCUDACopyPattern(Pattern):
    '''
    This pattern identifies if we create a constant tensor on CPU and immediately move it to GPU.
    example: torch.zeros((100, 100)).to("cuda")

    Pattern:
    built-in method                 |built-in method
        ...                         |    aten::to
            aten::fill_/aten::zero_ |        aten::_to_copy

    Algorithm:
    We start at the aten::to node, step to its parent's previous sibling,
    and check whether we find an aten::fill_/aten::zero_ as we keep going down the tree.
    We always select the last child in the children list when we go down the tree.
    If any step fails, it is not a match.
    '''
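    # A hedged illustration of the suggested fix (create the tensor directly
    # on the target device instead of filling it on CPU and copying):
    #     torch.zeros((100, 100), device="cuda")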

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Extra CUDA Copy Pattern"
        self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
        self.init_ops = {
            "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
        }

    @property
    def skip(self):
        return not self.prof.with_stack or not self.prof.record_shapes

    def match(self, event):
        # TODO: We should also check tensor identities
        if event.name != "aten::to":
            return False
        to_event = event
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::_to_copy":
            return False
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::copy_":
            return False
        # The first two arguments of aten::copy_ should have the same dtype
        dtypes = input_dtypes(event)
        if len(dtypes) < 2:
            return False
        if dtypes[0] is None or dtypes[0] != dtypes[1]:
            return False
        event = to_event
        # Up one level
        event = event.parent
        if event is None:
            return False
        # Check if we have an aten::fill_ in the previous leaf
        event = self.prev_of(event)
        if event is None:
            return False
        while event.children:
            event = event.children[-1]
            # aten::zero_ is a special optimization case where fill_ is not called
            if event.name in self.init_ops:
                return True
        return event.name in self.init_ops
        # TODO: Check if tensor is reused

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            size = shape[0]
            to_timer = benchmark.Timer(stmt='torch.ones(size).to("cuda")',
                                       globals={'size': size})
            de_timer = benchmark.Timer(stmt='torch.ones(size, device="cuda")',
                                       globals={'size': size})
            to_time = to_timer.timeit(10).mean
            de_time = de_timer.timeit(10).mean
            shapes_factor_map[shape] = de_time / to_time
        return shapes_factor_map


class ForLoopIndexingPattern(Pattern):
    '''
    This pattern identifies if we use a for loop to index a tensor in a way
    that can be vectorized.
    example:
    tensor = torch.empty((100, 100))
    for i in range(100):
        tensor[i] = i

    Pattern:
    aten::select | ... | aten::select | ... (Repeat)

    Algorithm:
    We start at node aten::select, and we check if we can find this alternating pattern.
    We also keep a set of visited events to avoid duplicate matches in the for loop.
    '''
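    # One possible vectorized rewrite of the example above (a sketch, not the
    # only option): build the row values in a single op instead of looping.
    #     tensor = torch.arange(100.0).unsqueeze(1).expand(100, 100)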

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "For Loop Indexing Pattern"
        self.description = "For loop indexing detected. Vectorization recommended."
        self.visited: Set[int] = set()

    def eventTreeTraversal(self):
        '''
        We need to use BFS traversal order to avoid duplicate matches.
        '''
        yield from traverse_bfs(self.event_tree)

    def match(self, event: _ProfilerEvent):
        if event.name != "aten::select":
            return False
        if event.id in self.visited:
            return False
        repeat_count = 1
        _, next = self.siblings_of(event)
        if len(next) <= 1:
            return False

        # Custom event list matching
        def same_ops(list1, list2):
            if len(list1) != len(list2):
                return False
            for op1, op2 in zip(list1, list2):
                if op1.name != op2.name:
                    return False
            return True

        # Record the ops between two aten::select
        next_select_idx = index_of_first_match(
            next, lambda e: e.name == "aten::select")
        if next_select_idx is None:
            return False
        indexing_ops = [event] + next[:next_select_idx]
        next = next[len(indexing_ops) - 1:]
        for i in range(0, len(next), len(indexing_ops)):
            if same_ops(indexing_ops, next[i:i + len(indexing_ops)]):
                repeat_count += 1
                self.visited.add(next[i].id)
            else:
                break
        return repeat_count >= 10


class FP32MatMulPattern(Pattern):

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "FP32 MatMul Pattern"
        self.description = (
            "You are currently using GPU that supports TF32. "
            "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
        )
        self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"

    @property
    def skip(self):
        if torch.version.hip is not None:
            has_tf32 = False
        else:
            # Anything below sm_80 predates Ampere and does not support TF32
            has_tf32 = all(
                int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
        return has_tf32 is False or super().skip or not self.prof.record_shapes

    def match(self, event: _ProfilerEvent):
        # Only TorchOp events carry the extra fields (allow_tf32_cublas) we check.
        if event.tag != _EventType.TorchOp:
            return False
        assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
        if event.name == "aten::mm":
            if event.extra_fields.allow_tf32_cublas is False:
                return True
        return False

    def report(self, event: _ProfilerEvent):
        return self.description

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
            fp32_timer = benchmark.Timer(stmt='torch.mm(matrixA, matrixB)',
                                         globals={
                                             "matrixA": matrixA,
                                             "matrixB": matrixB
                                         })
            tf32_timer = benchmark.Timer(
                stmt='torch.mm(matrixA, matrixB)',
                setup='torch.backends.cuda.matmul.allow_tf32 = True',
                globals={
                    "matrixA": matrixA,
                    "matrixB": matrixB
                })
            torch.backends.cuda.matmul.allow_tf32 = False
            fp32_time = fp32_timer.timeit(10).mean
            tf32_time = tf32_timer.timeit(10).mean
            shapes_factor_map[shape] = tf32_time / fp32_time
        return shapes_factor_map