neilisaac / torch (Python package), version 1.8.0
core/nomnigraph/op_gen.py

#!/usr/bin/env python
import argparse
from textwrap import dedent
from subprocess import call


def parse_lines(lines):
    """Expand macros and parse op definition lines into (ops dict, ordered op list)."""
    # States
    EMPTY = 0
    OP = 1
    MACRO = 2
    parse_state = EMPTY

    # Preprocess the macros
    curr_macro = ""
    macros = {}

    index = 0
    while index < len(lines):
        line = lines[index]
        if line.lower().startswith("macro"):
            assert parse_state == EMPTY
            macro_line = line.split(" ")
            # Support macros that look like attributes
            # e.g. macro - CONV_LIKE
            curr_macro = " ".join(macro_line[1:])
            assert curr_macro not in macros, 'Macro "{}" defined twice.'.format(
                curr_macro
            )
            macros[curr_macro] = []
            parse_state = MACRO
            lines = lines[:index] + lines[index + 1 :]
            continue
        elif line.lower().startswith("endmacro"):
            assert parse_state == MACRO
            parse_state = EMPTY
            lines = lines[:index] + lines[index + 1 :]
            continue
        elif parse_state == MACRO:
            macros[curr_macro].append(line)
            lines = lines[:index] + lines[index + 1 :]
            continue
        index += 1

    index = 0
    while index < len(lines):
        line = lines[index]
        if line in macros:
            lines = lines[:index] + macros[line] + lines[index + 1 :]
            index += len(macros[line]) - 1
        index += 1

    # Now parse the file
    curr_op = ""
    # dict of the form
    #  opName : { attributes: [], ... }
    ops = {}
    # To preserve parsing order for dependencies (for things like init_from)
    op_list = []

    for line in lines:
        if not len(line):
            continue
        if line[0] == "-":
            assert parse_state == OP
            attr = [_.strip() for _ in line[1:].split(":")]
            assert attr[0][0].isupper()
            if len(attr) == 2:  # attribute : type
                ops[curr_op]["attributes"].append((attr[0], attr[1]))
            elif len(attr) == 3:  # attribute : type : default
                ops[curr_op]["attributes"].append((attr[0], attr[1], attr[2]))
        else:
            op = [l.strip() for l in line.split(":")]
            assert len(op[0].split(" ")) == 1
            parse_state = OP
            curr_op = op[0]
            assert curr_op not in ops
            ops[curr_op] = {}
            op_list.append(curr_op)
            if len(op) > 1:
                ops[curr_op]["init_from"] = [op[1]]
            ops[curr_op]["attributes"] = []
    return ops, op_list
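
# Illustration (not part of the original file): a minimal sketch of the input
# format parse_lines expects, inferred from the parser above. The macro and op
# names used here (CONV_LIKE, Conv, Relu, ConvRelu, KernelShape, ...) are
# hypothetical.
_EXAMPLE_DEF_LINES = [
    "macro CONV_LIKE",
    "- KernelShape : vector<int>",
    "- Pads : vector<int>",
    "endmacro",
    "Conv",
    "CONV_LIKE",
    "- Group : int : 1",
    "Relu",
    "ConvRelu : Conv",
]
# parse_lines(_EXAMPLE_DEF_LINES) should return roughly:
#   op_list == ["Conv", "Relu", "ConvRelu"]
#   ops["Conv"]["attributes"] == [("KernelShape", "vector<int>"),
#                                 ("Pads", "vector<int>"),
#                                 ("Group", "int", "1")]
#   ops["ConvRelu"]["init_from"] == ["Conv"]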


def gen_class(op, op_def):
    """Generate the C++ class declaration for a single op from its parsed definition."""
    attributes = op_def["attributes"]
    attribute_args = []
    default_init = "NeuralNetOperator(NNKind::{op})".format(op=op)
    attribute_init = [default_init]
    attribute_declarations = []
    attribute_getters = []
    attribute_setters = []
    for attr in attributes:
        lower_name = attr[0][0].lower() + attr[0][1:]
        private_name = lower_name + "_"
        default_arg = "" if len(attr) < 3 else " = {}".format(attr[2])
        name = attr[0]
        t = attr[1]
        attr_arg = "{type} {lower_name}".format(
            type=t, lower_name=lower_name + default_arg
        )
        attr_init = "{private_name}({lower_name})".format(
            private_name=private_name, lower_name=lower_name)
        attr_declare = "{type} {private_name};".format(
            type=t, private_name=private_name)
        attr_get = dedent(
            """
              {type} get{name}() const {{
                return {private_name};
              }}
            """.format(
                type=t, name=name, private_name=private_name
            )
        )
        attr_set = dedent(
            """
              void set{name}({type} {lower_name}) {{
                {private_name} = {lower_name};
              }}
            """.format(
                type=t, name=name, private_name=private_name, lower_name=lower_name
            )
        )
        attribute_args.append(attr_arg)
        attribute_init.append(attr_init)
        attribute_declarations.append(attr_declare)
        attribute_getters.append(attr_get)
        attribute_setters.append(attr_set)

    extra_init = ""
    if "init_from" in op_def:
        for other_op in op_def["init_from"]:
            lower_other_op = other_op[0].lower() + other_op[1:]
            other_init = [default_init]
            for attr in attributes:
                lower_name = attr[0][0].lower() + attr[0][1:]
                private_name = lower_name + "_"
                other_init.append(
                    "{private_name}({other_op}.get{name}())".format(
                        name=attr[0], private_name=private_name, other_op=lower_other_op
                    )
                )
            init = dedent(
                """
                  {op}(const {other_op}& {lower_other_op}) :
                      {other_init} {{}}
                """.format(
                    op=op,
                    other_op=other_op,
                    lower_other_op=lower_other_op,
                    other_init=",\n      ".join(other_init),
                )
            )
            extra_init += init

    return dedent(
        """
        class {op} : public NeuralNetOperator {{
         public:
          {op}({attribute_args}) :
              {attribute_init} {{}}
          {extra_init}
          ~{op}() {{}}

          NOMNIGRAPH_DEFINE_NN_RTTI({op});
        {getters}{setters}
         private:
          {attribute_declarations}
        }};

        """.format(
            op=op,
            extra_init=extra_init,
            getters="".join(attribute_getters),
            setters="".join(attribute_setters),
            attribute_args=",\n".join(attribute_args),
            attribute_init=",\n".join(attribute_init),
            attribute_declarations="\n".join(attribute_declarations),
        )
    )
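
# Illustration (not part of the original file): for a hypothetical definition
#
#   Relu
#   - Alpha : float : 1.0
#
# gen_class("Relu", {"attributes": [("Alpha", "float", "1.0")]}) emits a class
# declaration which, after clang-format has run, looks roughly like:
#
#   class Relu : public NeuralNetOperator {
#    public:
#     Relu(float alpha = 1.0) : NeuralNetOperator(NNKind::Relu), alpha_(alpha) {}
#
#     ~Relu() {}
#
#     NOMNIGRAPH_DEFINE_NN_RTTI(Relu);
#
#     float getAlpha() const {
#       return alpha_;
#     }
#
#     void setAlpha(float alpha) {
#       alpha_ = alpha;
#     }
#
#    private:
#     float alpha_;
#   };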


def gen_classes(ops, op_list):
    """Concatenate the generated class declarations for every op, in definition order."""
    f = ""
    for op in op_list:
        f += gen_class(op, ops[op])
    return f


def gen_enum(op_list):
    """Generate the comma-separated list of op names written to OpEnum.h."""
    return ",\n".join(op_list) + "\n"


def gen_names(op_list):
    """Generate the switch cases mapping NNKind values to op name strings."""
    f = ""
    for op in op_list:
        f += dedent(
            """
            case NNKind::{name}:
                return \"{name}\";
            """.format(
                name=op
            )
        )
    return f
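
# Illustration (not part of the original file): for op_list == ["Conv", "Relu"]
# (hypothetical names), gen_enum returns "Conv,\nRelu\n" (the contents of
# OpEnum.h) and gen_names returns the switch cases written to OpNames.h:
#
#   case NNKind::Conv:
#       return "Conv";
#
#   case NNKind::Relu:
#       return "Relu";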


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate op files.")
    parser.add_argument("--install_dir", help="installation directory")
    parser.add_argument("--source_def", help="ops.def", action="append")
    args = parser.parse_args()
    install_dir = args.install_dir
    sources = args.source_def

    lines = []
    for source in sources:
        with open(source, "rb") as f:
            lines_tmp = f.readlines()
            lines += [l.strip().decode("utf-8") for l in lines_tmp]
    ops, op_list = parse_lines(lines)

    with open(install_dir + "/OpClasses.h", "wb") as f:
        f.write(gen_classes(ops, op_list).encode("utf-8"))
    with open(install_dir + "/OpNames.h", "wb") as f:
        f.write(gen_names(op_list).encode("utf-8"))
    with open(install_dir + "/OpEnum.h", "wb") as f:
        f.write(gen_enum(op_list).encode("utf-8"))

    # Best-effort formatting: silently skip if clang-format is not available.
    try:
        cmd = ["clang-format", "-i", install_dir + "/OpClasses.h"]
        call(cmd)
        cmd = ["clang-format", "-i", install_dir + "/OpNames.h"]
        call(cmd)
        cmd = ["clang-format", "-i", install_dir + "/OpEnum.h"]
        call(cmd)
    except Exception:
        pass
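
# Illustration (not part of the original file): a typical invocation, assuming
# a definition file named ops.def and an output directory named "generated":
#
#   python op_gen.py --source_def ops.def --install_dir generated
#
# --source_def may be repeated to concatenate several definition files. The
# generated OpClasses.h, OpNames.h, and OpEnum.h are written into --install_dir
# and then run through clang-format if it is available on PATH.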