Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tensorflow / purelib / tensorflow / python / autograph / converters / arg_defaults.py
Size: Mime:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modifies the signature to allow resolving the value of default arguments.

Normally, function symbols are captured either in a function's globals or
closure. This is not true for default arguments, which are evaluated when the
function is defined:

    b = 1
    c = 2
    def f(a=b + 1):
      return a + c

In the above example, the namespace of the function would include `c = 2` but
not `b`.

If we were to naively generate a new function:

    def new_f(a=b + 1):
      return a + c

The generated code would fail to load unless we exposed a symbol `b`. Capturing
the closure of such an expression is difficult. However, we can capture the
default value of argument `a` with relative ease.

This converter replaces all default argument expressions with a constant so
that they don't cause loading to fail. This requires that the default values
are reset after loading the transformed function:

    def new_f(a=None):
      return a + c

    # ... later, after new_f was loaded ...
    new_f.__defaults__ = f.__defaults__

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import parser


class _Function(object):
  pass


class ArgDefaultsTransformer(converter.Base):
  """Transforms top level argument defaults.

  Every default value of the top level function (or lambda) is replaced with
  the constant expression `None`. The generated code therefore no longer
  references the symbols used by the original default expressions. The real
  default values must be restored after the transformed function is loaded,
  e.g. by copying `__defaults__` from the original function.
  """

  def _visit_function_args(self, node):
    """Visits only the `args` of a function-like node inside a new scope.

    Only the top level function is modified, so the body (or, for lambdas,
    the expression) is deliberately not visited.

    Args:
      node: A FunctionDef or Lambda AST node; its `args` attribute is
        replaced with the visited result.

    Returns:
      The same node.
    """
    self.state[_Function].enter()
    try:
      node.args = self.visit(node.args)
    finally:
      # Keep the scope stack balanced even if visiting the arguments fails.
      self.state[_Function].exit()
    return node

  def visit_Lambda(self, node):
    return self._visit_function_args(node)

  def visit_FunctionDef(self, node):
    return self._visit_function_args(node)

  def visit_arguments(self, node):
    """Replaces every default value in `node` with a `None` expression.

    Args:
      node: An `arguments` AST node; its `defaults` and `kw_defaults` lists
        are modified in place.

    Returns:
      The same node.
    """
    # NOTE(review): `level` presumably counts an implicit root scope plus one
    # per entered _Function, so values above 2 indicate a nested function.
    # Function bodies are never visited here, so this is a defensive guard;
    # confirm the counting against converter.Base.
    if self.state[_Function].level > 2:
      return node

    # A fresh expression is parsed for each slot so that no AST node object
    # is shared between positions. Slice assignment keeps the list objects.
    node.defaults[:] = [
        parser.parse_expression('None') for _ in node.defaults]

    # A `None` entry in kw_defaults means the keyword-only argument has no
    # default at all; it must stay `None` so the argument remains required.
    node.kw_defaults[:] = [
        None if d is None else parser.parse_expression('None')
        for d in node.kw_defaults]

    # Only the top level function is modified - no need to visit the children.
    return node


def transform(node, ctx):
  """Replaces the argument defaults of the top level function with `None`.

  The original default values are not preserved in the AST; they must be
  restored on the loaded function afterwards (e.g. via `__defaults__`).

  Args:
    node: AST to transform.
    ctx: EntityContext
  Returns:
    The transformed AST, as returned by visiting `node` with an
    ArgDefaultsTransformer. Only the default values of the top level
    function's arguments are changed.
  """
  return ArgDefaultsTransformer(ctx).visit(node)