Source code for qhbmlib.models.energy_utils

# Copyright 2021 The QHBM Library Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for the energy_model module."""

import itertools
from typing import List, Optional

import tensorflow as tf


def check_bits(bits: List[int]) -> List[int]:
  """Validates that `bits` contains no repeated labels.

  Args:
    bits: Candidate list of bit labels.

  Returns:
    The same `bits` object, unchanged, so the check can be used inline.

  Raises:
    ValueError: If any label appears more than once.
  """
  distinct_labels = set(bits)
  # A set can never be larger than the list it came from, so a strictly
  # smaller set means at least one label was repeated.
  has_duplicates = len(distinct_labels) < len(bits)
  if has_duplicates:
    raise ValueError("All entries of `bits` must be unique.")
  return bits


def check_order(order: int) -> int:
  """Validates that `order` is a positive integer.

  Args:
    order: Candidate parity order.

  Returns:
    `order` itself, unchanged, when it is a positive integer.

  Raises:
    TypeError: If `order` is not an `int`.
    ValueError: If `order` is zero or negative.
  """
  if isinstance(order, int):
    if order > 0:
      return order
    raise ValueError("`order` must be greater than zero.")
  raise TypeError("`order` must be an integer.")


class SpinsFromBitstrings(tf.keras.layers.Layer):
  """Simple layer taking bits to spins."""

  def __init__(self):
    """Initializes a SpinsFromBitstrings."""
    super().__init__(trainable=False)

  def call(self, inputs):
    """Returns the spins corresponding to the input bitstrings.

    Note that this maps |0> -> +1 and |1> -> -1. This is in accordance with
    the usual interpretation of the Bloch sphere.

    Args:
      inputs: Tensor of bitstrings with entries 0 or 1.

    Returns:
      A float32 tensor of the same shape holding `1 - 2 * inputs`.
    """
    # Cast before the affine map: doing the arithmetic in the input dtype
    # wraps around for unsigned types (1 - 2 == 255 in uint8) and fails for
    # boolean tensors. For signed-int or float 0/1 inputs the result is the
    # same as computing first and casting afterwards.
    bits = tf.cast(inputs, tf.float32)
    return 1.0 - 2.0 * bits
class VariableDot(tf.keras.layers.Layer):
  """Utility layer for dotting input with a same-sized variable."""

  def __init__(self,
               initializer: Optional[
                   tf.keras.initializers.Initializer] = None):
    """Initializes a VariableDot layer.

    Args:
      initializer: A `tf.keras.initializers.Initializer` which specifies how to
        initialize the values of the parameters. If omitted, a new
        `tf.keras.initializers.RandomUniform()` is created for this layer.
    """
    super().__init__()
    # The default is built here rather than in the signature. A default
    # argument is evaluated only once, so all layers would otherwise share a
    # single initializer instance. Recent Keras releases warn that reusing an
    # unseeded initializer instance can yield identical values per call.
    if initializer is None:
      initializer = tf.keras.initializers.RandomUniform()
    self._initializer = initializer

  def build(self, input_shape):
    """Initializes the internal variables.

    Creates a trainable float32 `kernel` whose length equals the size of the
    last input dimension.
    """
    self.kernel = self.add_weight(
        name="kernel",
        shape=(input_shape[-1],),
        dtype=tf.float32,
        initializer=self._initializer,
        trainable=True)

  def call(self, inputs):
    """Returns the dot product between the inputs and this layer's variables.

    The kernel is broadcast against `inputs` and summed over the last axis.
    """
    return tf.reduce_sum(inputs * self.kernel, -1)
class Parity(tf.keras.layers.Layer):
  """Computes the parities of input spins."""

  def __init__(self, bits: List[int], order: int):
    """Initializes a Parity layer.

    Args:
      bits: Unique labels for the bits on which this distribution is supported.
      order: Maximum size of bit groups to take the parity of.
    """
    super().__init__(trainable=False)
    bits = check_bits(bits)
    order = check_order(order)
    positions = range(len(bits))
    # Every group of bit positions of size 1..order: all singletons first,
    # then all pairs, and so on, each size in itertools.combinations order.
    index_groups = [
        group for size in range(1, order + 1)
        for group in itertools.combinations(positions, size)
    ]
    self.indices = tf.ragged.stack(index_groups)
    self.num_terms = len(index_groups)

  def call(self, inputs):
    """Returns a batch of parities corresponding to the input bitstrings.

    Output column t is the product of the entries of `inputs` selected by
    `self.indices[t]` along the last axis.
    """
    batch_size = tf.shape(inputs)[0]
    # Accumulated as [num_terms, batch] so each term can be written as a row,
    # then transposed to [batch, num_terms] on return.
    per_term = tf.zeros([self.num_terms, batch_size])
    for term in tf.range(self.num_terms):
      selected = tf.gather(inputs, self.indices[term], axis=-1)
      term_parity = tf.reduce_prod(selected, -1)
      per_term = tf.tensor_scatter_nd_update(per_term, [[term]], [term_parity])
    return tf.transpose(per_term)