# Source code for qhbmlib.utils

# Copyright 2021 The QHBM Library Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used across more than one file."""

import tensorflow as tf


class Squeeze(tf.keras.layers.Layer):
  """Wraps tf.squeeze in a Keras Layer."""

  def __init__(self, axis=None, **kwargs):
    """Initializes a Squeeze layer.

    Args:
      axis: An optional list of ints. Defaults to []. If specified, only
        squeezes the dimensions listed. The dimension index starts at 0. It is
        an error to squeeze a dimension that is not 1. Must be in the range
        [-rank(input), rank(input)). Must be specified if input is a
        RaggedTensor.
      **kwargs: Optional keyword arguments forwarded unchanged to the
        `tf.keras.layers.Layer` constructor (for example `name`).
    """
    # Forwarding kwargs lets callers name the layer and lets `from_config`
    # rebuild it from the dictionary produced by `get_config`.
    super().__init__(**kwargs)
    if axis is None:
      axis = []
    self._axis = axis

  def call(self, inputs):
    """Applies tf.squeeze to the inputs.

    Args:
      inputs: The tensor to squeeze.

    Returns:
      The result of `tf.squeeze(inputs, axis=axis)` for the `axis` given at
      construction time.
    """
    return tf.squeeze(inputs, axis=self._axis)

  def get_config(self):
    """Returns the layer configuration, including the `axis` argument.

    Returns:
      The base layer config dictionary extended with an `"axis"` entry, so
      that `Squeeze(**config)` reconstructs an equivalent layer.
    """
    config = super().get_config()
    axis = self._axis
    # Hand back a plain list copy for sequence axes so the config does not
    # alias the attribute stored on the layer.
    if isinstance(axis, (list, tuple)):
      axis = list(axis)
    config["axis"] = axis
    return config
def weighted_average(counts: tf.Tensor, values: tf.Tensor):
  """Returns the weighted average of input values.

  Subtensor `i` of `values` is multiplied by `counts[i]` resulting in a weighted
  version of values; the weighted sum across the first dimension is then divided
  by the sum of `counts`.

  Args:
    counts: Non-negative integers of shape [batch_size].
    values: Floats of shape [batch_size, ...].

  Returns:
    Tensor of shape [...] which is the weighted average.
  """
  values = tf.convert_to_tensor(values)
  # `tf.einsum` requires both operands to share a dtype.  Casting the counts
  # to the dtype of `values` (rather than always to float32) keeps the
  # float32 behavior unchanged while also accepting e.g. float64 values.
  typed_counts = tf.cast(counts, values.dtype)
  # Contract the batch axis: sum_i counts[i] * values[i, ...].
  weighted_values = tf.einsum("i,i...->...", typed_counts, values)
  return weighted_values / tf.reduce_sum(typed_counts)
def unique_bitstrings_with_counts(bitstrings, out_idx=tf.dtypes.int32):
  """Extracts the unique bitstrings in the given bitstring tensor.

  Thin wrapper around the `UniqueWithCountsV2` raw op applied along axis 0.

  Args:
    bitstrings: 2-D `tf.Tensor`, interpreted as a list of bitstrings.
    out_idx: An optional `tf.DType` from: `tf.int32`, `tf.int64`. Defaults to
      `tf.int32`. Specifies the type of `count` output.

  Returns:
    y: 2-D `tf.Tensor` of same dtype as `bitstrings`, containing the unique
      0-axis entries of `bitstrings`.
    idx: The index of each value of the input in the unique output `y`.
    count: 1-D `tf.Tensor` of dtype `out_idx` such that `count[i]` is the
      number of occurrences of `y[i]` in `bitstrings`.
  """
  unique_rows, row_positions, occurrences = tf.raw_ops.UniqueWithCountsV2(
      x=bitstrings,
      axis=[0],
      out_idx=out_idx,
  )
  return unique_rows, row_positions, occurrences
def expand_unique_results(y, idx):
  """Inverse of unique_bitstrings_with_counts.

  Args:
    y: Values to pick according to `idx`.
    idx: The index at which to place each value of `y` in the output.

  Returns:
    `tf.Tensor` `expanded` such that `expanded[i] == y[idx[i]]`.
  """
  # Gathering along the first axis repeats each entry of `y` once for every
  # position in `idx` that refers to it.
  return tf.gather(y, idx, axis=0)