Repository URL to install this package:
|
Version:
1.14.0 ▾
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for computing common window functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export('signal.hann_window')
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  The result is the raised cosine `w[k] = 0.5 - 0.5 * cos(2 * pi * k / N)`
  for `k = 0, ..., window_length - 1`, where `N` is `window_length - 1` for
  a symmetric window and `window_length` for a periodic window of even
  length. For an odd `window_length` the periodic and symmetric windows are
  identical. A `window_length` of 1 yields `[1.]`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  # A Hann window is the raised cosine window with alpha = beta = 0.5.
  return _raised_cosine_window(
      name=name,
      default_name='hann_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.5,
      b=0.5)
@tf_export('signal.hamming_window')
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  The result is the raised cosine `w[k] = 0.54 - 0.46 * cos(2 * pi * k / N)`
  for `k = 0, ..., window_length - 1`, where `N` is `window_length - 1` for
  a symmetric window and `window_length` for a periodic window of even
  length. For an odd `window_length` the periodic and symmetric windows are
  identical. A `window_length` of 1 yields `[1.]`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hamming]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  # A Hamming window is the raised cosine window with alpha = 0.54 and
  # beta = 0.46.
  return _raised_cosine_window(
      name=name,
      default_name='hamming_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.54,
      b=0.46)
def _raised_cosine_window(name, default_name, window_length, periodic,
                          dtype, a, b):
  """Helper function for computing a raised cosine window.

  Computes `w[k] = a - b * cos(2 * pi * k / N)` for
  `k = 0, ..., window_length - 1`, where `N = window_length - 1` for a
  symmetric window and `N = window_length` for a periodic window of even
  length. For an odd `window_length`, `N = window_length - 1` in both cases,
  so the periodic and symmetric windows coincide. A window of length 1 is
  returned as `[1.]`.

  Args:
    name: Name to use for the scope.
    default_name: Default name to use for the scope.
    window_length: A scalar `Tensor` or integer indicating the window length.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window.
    dtype: A floating point `DType`, or any value `dtypes.as_dtype` accepts
      (e.g. `np.float32` or `'float32'`).
    a: The alpha parameter to the raised cosine window.
    b: The beta parameter to the raised cosine window.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not scalar or `periodic` is not scalar.
  """
  # Normalize so that numpy dtypes / strings work too; a `DType` passes
  # through unchanged.
  dtype = dtypes.as_dtype(dtype)
  if not dtype.is_floating:
    raise ValueError('dtype must be a floating point type. Found %s' % dtype)

  with ops.name_scope(name, default_name, [window_length, periodic]):
    window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32,
                                          name='window_length')
    window_length.shape.assert_has_rank(0)
    # Validate `periodic` before any early return, so that an invalid
    # `periodic` is rejected even when the window has length 1.
    periodic = ops.convert_to_tensor(periodic, dtype=dtypes.bool,
                                     name='periodic')
    periodic.shape.assert_has_rank(0)

    window_length_const = tensor_util.constant_value(window_length)
    if window_length_const == 1:
      # The general formula below would compute 0 / 0 (N == 0) here.
      return array_ops.ones([1], dtype=dtype)

    periodic = math_ops.cast(periodic, dtypes.int32)
    # Only an even-length periodic window extends the denominator by one.
    # NOTE(review): for odd lengths this makes `periodic` a no-op; some
    # libraries use N = window_length for every periodic window — confirm
    # this difference is intended.
    even = 1 - math_ops.mod(window_length, 2)
    n = math_ops.cast(window_length + periodic * even - 1, dtype=dtype)
    count = math_ops.cast(math_ops.range(window_length), dtype)
    cos_arg = constant_op.constant(2 * np.pi, dtype=dtype) * count / n

    def _window():
      """Builds the raised cosine values `a - b * cos(cos_arg)`."""
      return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype)

    if window_length_const is not None:
      return _window()
    # The length is only known at run time, so the length-1 special case has
    # to be selected inside the graph; the cosine ops are built lazily in the
    # false branch.
    return control_flow_ops.cond(
        math_ops.equal(window_length, 1),
        lambda: array_ops.ones([1], dtype=dtype),
        _window)