# Copyright 2021 The QHBM Library Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for inference on energy functions represented by a BitstringEnergy."""
import abc
import functools
import itertools
from typing import Union
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from qhbmlib.models import energy
from qhbmlib import utils
def preface_inference(f):
"""Wraps given function with things to run before every inference call.
Args:
f: The method of `EnergyInference` to wrap.
Returns:
wrapper: The wrapped function.
"""
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
self._preface_inference() # pylint: disable=protected-access
return f(self, *args, **kwargs)
return wrapper
[docs]class EnergyInferenceBase(tf.keras.layers.Layer, abc.ABC):
r"""Defines the interface for inference on BitstringEnergy objects.
Let $E$ be the energy function defined by a given `BitstringEnergy`, and let
$X$ be the set of bitstrings in the domain of $E$. Associated with $E$ is
a probability distribution
$$p(x) = \frac{e^{-E(x)}}{\sum_{y\in X} e^{-E(y)}},$$
which we call the Energy Based Model (EBM) associated with $E$. Inference
in this class means estimating quantities of interest relative to the EBM.
"""
def __init__(self,
input_energy: energy.BitstringEnergy,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes an EnergyInferenceBase.
Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(name=name)
self._energy = input_energy
self._energy.build([None, self._energy.num_bits])
self._tracked_variables = input_energy.variables
if len(self._tracked_variables) == 0:
self._checkpoint = False
else:
self._tracked_variables_checkpoint = [
tf.Variable(v.read_value(), trainable=False)
for v in self._tracked_variables
]
self._checkpoint = True
if initial_seed is None:
self._update_seed = tf.Variable(True, trainable=False)
else:
self._update_seed = tf.Variable(False, trainable=False)
self._seed = tf.Variable(
tfp.random.sanitize_seed(initial_seed), trainable=False)
self._first_inference = tf.Variable(True, trainable=False)
@property
def energy(self):
"""The energy function which sets the probabilities for this EBM."""
return self._energy
@property
def seed(self):
"""Current TFP compatible seed controlling sampling behavior.
PRNG seed; see tfp.random.sanitize_seed for details. This seed will be used
in the `sample` method. If None, the seed is updated after every inference
call. Otherwise, the seed is fixed.
"""
return self._seed
@seed.setter
def seed(self, initial_seed: Union[None, tf.Tensor]):
"""Sets a new value of the random seed.
Args:
initial_seed: see `self.seed` for details.
"""
if initial_seed is None:
self._update_seed.assign(True)
else:
self._update_seed.assign(False)
self._seed.assign(tfp.random.sanitize_seed(initial_seed))
@property
def variables_updated(self):
"""Returns True if tracked variables do not have the checkpointed values."""
if self._checkpoint:
variables_not_equal_list = tf.nest.map_structure(
lambda v, vc: tf.math.reduce_any(tf.math.not_equal(v, vc)),
self._tracked_variables, self._tracked_variables_checkpoint)
return tf.math.reduce_any(tf.stack(variables_not_equal_list))
else:
return False
def _checkpoint_variables(self):
"""Checkpoints the currently tracked variables."""
if self._checkpoint:
tf.nest.map_structure(lambda v, vc: vc.assign(v), self._tracked_variables,
self._tracked_variables_checkpoint)
def _preface_inference(self):
"""Things all energy inference methods do before proceeding.
Called by `preface_inference` before the wrapped inference method.
Currently includes:
- run `self._ready_inference` if this is first call of a wrapped function
- change the seed if not set by the user during initialization
- run `self._ready_inference` if tracked energy parameters changed
Note: subclasses should take care to call the superclass method.
"""
if self._first_inference:
self._checkpoint_variables()
self._ready_inference()
self._first_inference.assign(False)
if self._update_seed:
new_seed, _ = tfp.random.split_seed(self.seed)
self._seed.assign(new_seed)
if self.variables_updated:
self._checkpoint_variables()
self._ready_inference()
@abc.abstractmethod
def _ready_inference(self):
"""Performs computations common to all inference methods.
Contains inference code that must be run first if the variables of
`self.energy` have been updated since the last time inference was performed.
"""
[docs] @preface_inference
def call(self, inputs, *args, **kwargs):
"""Calls this layer on the given inputs."""
return self._call(inputs, *args, **kwargs)
[docs] @preface_inference
def entropy(self):
"""Returns an estimate of the entropy."""
return self._entropy()
[docs] @preface_inference
def expectation(self, function):
"""Returns an estimate of the expectation value of the given function.
Args:
function: Mapping from a 2D tensor of bitstrings to a possibly nested
structure. The structure must have atomic elements all of which are
float tensors with the same batch size as the input bitstrings.
"""
return self._expectation(function)
[docs] @preface_inference
def log_partition(self):
"""Returns an estimate of the log partition function."""
return self._log_partition()
[docs] @preface_inference
def sample(self, num_samples: int):
"""Returns samples from the EBM corresponding to `self.energy`.
Args:
num_samples: Number of samples to draw from the EBM.
"""
return self._sample(num_samples)
@abc.abstractmethod
def _call(self, inputs, *args, **kwargs):
"""Default implementation wrapped by `self.call`."""
raise NotImplementedError()
@abc.abstractmethod
def _entropy(self):
"""Default implementation wrapped by `self.entropy`."""
raise NotImplementedError()
@abc.abstractmethod
def _expectation(self, function):
"""Default implementation wrapped by `self.expectation`."""
raise NotImplementedError()
@abc.abstractmethod
def _log_partition(self):
"""Default implementation wrapped by `self.log_partition`."""
raise NotImplementedError()
@abc.abstractmethod
def _sample(self, num_samples: int):
"""Default implementation wrapped by `self.sample`."""
raise NotImplementedError()
[docs]class EnergyInference(EnergyInferenceBase):
"""Provides some default method implementations."""
def __init__(self,
input_energy: energy.BitstringEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes an EnergyInference.
Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(input_energy, initial_seed, name)
self.num_expectation_samples = num_expectation_samples
def _entropy(self):
"""Default implementation wrapped by `self.entropy`."""
return self.expectation(self.energy) + self.log_partition()
def _expectation(self, function):
"""Default implementation wrapped by `self.expectation`.
Estimates an expectation value using sample averaging.
"""
@tf.custom_gradient
def _inner_expectation():
"""Enables derivatives."""
samples = tf.stop_gradient(self.sample(self.num_expectation_samples))
bitstrings, _, counts = utils.unique_bitstrings_with_counts(samples)
# TODO(#157): try to parameterize the persistence.
with tf.GradientTape() as values_tape:
# Adds variables in `self.energy` to `variables` argument of `grad_fn`.
values_tape.watch(self.energy.trainable_variables)
values = function(bitstrings)
average_of_values = tf.nest.map_structure(
lambda x: utils.weighted_average(counts, x), values)
def grad_fn(*upstream, variables):
"""See equation A5 in the QHBM paper appendix for details.
# TODO(#119): confirm equation number.
"""
function_grads = values_tape.gradient(
average_of_values,
variables,
output_gradients=upstream,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
flat_upstream = tf.nest.flatten(upstream)
flat_values = tf.nest.flatten(values)
combined_flat = tf.nest.map_structure(lambda x, y: x * y, flat_upstream,
flat_values)
combined_flat_sum = tf.nest.map_structure(
lambda x: tf.map_fn(tf.reduce_sum, x), combined_flat)
combined_sum = tf.reduce_sum(tf.stack(combined_flat_sum), 0)
average_of_combined_sum = utils.weighted_average(counts, combined_sum)
# Compute grad E terms.
with tf.GradientTape() as tape:
energies = self.energy(bitstrings)
energies_grads = tape.jacobian(
energies,
variables,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
average_of_energies_grads = tf.nest.map_structure(
lambda x: utils.weighted_average(counts, x), energies_grads)
product_of_averages = tf.nest.map_structure(
lambda x: x * average_of_combined_sum, average_of_energies_grads)
products = tf.nest.map_structure(
lambda x: tf.einsum("i...,i->i...", x, combined_sum),
energies_grads)
average_of_products = tf.nest.map_structure(
lambda x: utils.weighted_average(counts, x), products)
# Note: upstream gradient is already a coefficient in poa, aop, and fg.
return tuple(), [
poa - aop + fg for poa, aop, fg in zip(
product_of_averages, average_of_products, function_grads)
]
return average_of_values, grad_fn
return _inner_expectation()
def _log_partition(self):
"""Default implementation wrapped by `self.log_partition`."""
@tf.custom_gradient
def _inner_log_partition():
"""Wraps forward pass computaton."""
result = self._log_partition_forward_pass()
# Adds variables in `self.energy` to `variables` argument of `grad_fn`.
_ = [tf.identity(x) for x in self.energy.trainable_variables]
grad_fn = self._log_partition_grad_generator()
return result, grad_fn
return _inner_log_partition()
def _log_partition_forward_pass(self):
r"""Returns approximation to the log partition function.
The calculation uses the uniform distribution to approximate the
partition function using Monte Carlo integration. See section 11.2 of [1]
for details on Monte Carlo integration.
TODO (#216): decrease variance of this estimator.
Given an energy function $E$, the associated partition function is defined
$$
Z = \sum_x e^{-E(x)},
$$
where $x$ ranges over the domain of $E$. We do not want to compute this sum
directly because it is over a set exponentially large in the number of bits.
Instead, we can use Monte Carlo integration. To support this, consider
rewriting the sum as an expectation value with respect to the uniform
distribution $u(x)$:
$$
Z = \sum_x e^{-E(x)}
= \sum_x u(x) 2^n e^{-E(x)}
= 2^n \mathbb{E}\left[e^{-E(\cdot)}\right],
$$
where $n$ is the number of bits in each uniform sample. Now draw $N_s$
samples from the uniform distribution. Then we can approximate the
expectation value as (equation 11.2)
$$
2^n \mathbb{E}\left[e^{-E(\cdot)}\right]
\approx\frac{2^n}{N_s}\sum_{i=1}^{N_s} e^{-E(x_i)}.
$$
Next, what we are really interested in is the logarithm of the partition
function. Then we have
$$
\log Z
\approx \log \left(\frac{2^n}{N_s}\sum_{i=1}^{N_s} e^{-E(x_i)}\right)
= n \log 2 - \log N_s + \log \sum_{i=1}^{N_s} e^{-E(x_i)}.
$$
#### References
[1]: Murphy, Kevin P. (2023).
Probabilistic Machine Learning: Advanced Topics.
MIT Press.
"""
n = self.energy.num_bits
n_s = self.num_expectation_samples
# Sample from the uniform distribution
samples = tfp.distributions.Bernoulli(logits=tf.zeros([n])).sample(n_s)
energies = self.energy(samples)
return n * tf.math.log(2.0) - tf.math.log(tf.cast(
n_s, tf.float32)) + tf.math.reduce_logsumexp(-1.0 * energies)
def _log_partition_grad_generator(self):
"""Returns default estimator for the log partition function derivative."""
def grad_fn(upstream, variables):
"""See equation C2 in the appendix. TODO(#119)"""
samples = self.sample(self.num_expectation_samples)
unique_samples, _, counts = utils.unique_bitstrings_with_counts(samples)
with tf.GradientTape() as tape:
unique_energies = self.energy(unique_samples)
unique_jacobians = tape.jacobian(
unique_energies,
variables,
unconnected_gradients=tf.UnconnectedGradients.ZERO)
average_jacobians = tf.nest.map_structure(
lambda j: utils.weighted_average(counts, j), unique_jacobians)
return tuple(), [upstream * (-1.0 * aj) for aj in average_jacobians]
return grad_fn
[docs]class AnalyticEnergyInference(EnergyInference):
"""Uses an explicit categorical distribution to implement parent functions."""
def __init__(self,
input_energy: energy.BitstringEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes an AnalyticEnergyInference.
Internally, this class saves all possible bitstrings as a tensor, whose
energies are calculated relative to an input energy function for sampling
and other inference tasks.
Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(input_energy, num_expectation_samples, initial_seed, name)
self._all_bitstrings = tf.constant(
list(itertools.product([0, 1], repeat=input_energy.num_bits)),
dtype=tf.int8)
self._logits_variable = tf.Variable(
-input_energy(self.all_bitstrings), trainable=False)
self._distribution = tfd.Categorical(logits=self._logits_variable)
@property
def all_bitstrings(self):
"""Returns every bitstring."""
return self._all_bitstrings
@property
def all_energies(self):
"""Returns the energy of every bitstring."""
return self.energy(self.all_bitstrings)
@property
def distribution(self):
"""Categorical distribution set during `self._ready_inference`."""
return self._distribution
def _ready_inference(self):
"""See base class docstring."""
self._logits_variable.assign(-self.all_energies)
def _call(self, inputs, *args, **kwargs):
"""See base class docstring."""
if inputs is None:
return self.distribution
else:
return self.sample(inputs)
def _entropy(self):
"""See base class docstring."""
return self.distribution.entropy()
def _log_partition_forward_pass(self):
"""See base class docstring."""
# TODO(#115)
return tf.reduce_logsumexp(self.distribution.logits_parameter())
def _sample(self, num_samples: int):
"""See base class docstring."""
return tf.gather(
self.all_bitstrings,
self.distribution.sample(num_samples, seed=self.seed),
axis=0)
[docs]class BernoulliEnergyInference(EnergyInference):
"""Manages inference for a Bernoulli defined by spin energies."""
def __init__(self,
input_energy: energy.BernoulliEnergy,
num_expectation_samples: int,
initial_seed: Union[None, tf.Tensor] = None,
name: Union[None, str] = None):
"""Initializes a BernoulliEnergyInference.
Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
initial_seed: PRNG seed; see tfp.random.sanitize_seed for details. This
seed will be used in the `sample` method. If None, the seed is updated
after every inference call. Otherwise, the seed is fixed.
name: Optional name for the model.
"""
super().__init__(input_energy, num_expectation_samples, initial_seed, name)
self._logits_variable = tf.Variable(input_energy.logits, trainable=False)
self._distribution = tfd.Bernoulli(
logits=self._logits_variable, dtype=tf.int8)
@property
def distribution(self):
"""Bernoulli distribution set during `self._ready_inference`."""
return self._distribution
def _ready_inference(self):
"""See base class docstring."""
self._logits_variable.assign(self.energy.logits)
def _call(self, inputs, *args, **kwargs):
"""See base class docstring."""
if inputs is None:
return self.distribution
else:
return self.sample(inputs)
def _entropy(self):
"""Returns the exact entropy.
The total entropy of a set of spins is the sum of each individual spin's
entropies.
"""
return tf.reduce_sum(self.distribution.entropy())
def _log_partition_forward_pass(self):
r"""Returns the exact log partition function.
For a single spin of energy $\theta$, the partition function is
$$Z_\theta = \exp(\theta) + \exp(-\theta).$$
Since each spin is independent, the total log partition function is
the sum of the individual spin log partition functions.
"""
thetas = 0.5 * self.energy.logits
single_log_partitions = tf.math.log(
tf.math.exp(thetas) + tf.math.exp(-thetas))
return tf.math.reduce_sum(single_log_partitions)
def _sample(self, num_samples: int):
"""See base class docstring"""
return self.distribution.sample(num_samples, seed=self.seed)
class GibbsWithGradientsKernel(tfp.mcmc.TransitionKernel):
"""Implements the Gibbs With Gradients update rule.
See Algorithm 1 of
https://arxiv.org/abs/2102.04509v2
for details. Here we summarize the motivations for Gibbs With Gradients
given in the paper at the link.
Consider Gibbs sampling, a special case of Metropolis-Hastings for discrete
distributions. For a distribution over bitstrings, the Gibbs sampler works as
follows: given bitstring x, propose an index, say i, drawn from some proposal
distribution q; form the conditional distribution of x[i] given all other
entries; sample a new bit b from the conditional and set x[i] = b.
The core difficulty with this simple strategy is intelligently choosing the
index to update. This means the proposal distribution q(i) should have more
mass on indices that are more likely to require updates. For example, in
sampling from a model of the MNIST dataset one can take advantage of the fact
that almost all dimensions represent the background and therefore are unlikely
to change sample-to-sample.
A previously known way to achieve this behavior is to use a proposal
distribution which is "locally-informed". The local information used is the
difference in log probability between the current state and states in a
Hamming ball around the current state, as in equation 2. This focuses
probability on the indices most likely to lower the energy.
However, locally-informed proposals scale poorly in the data dimensionality:
they require evaluating the log probability at every point in the Hamming ball
around the current state. The authors propose a method to approximate these
evaluations, which requires only a single evaluation of the log probability
function as well as the derivative. Since the method uses derivative
information, the authors name it Gibbs With Gradients.
This transition kernel implements the Gibbs With Gradients update rule.
"""
def __init__(self, input_energy):
"""Initializes a GibbsWithGradientsKernel.
Args:
input_energy: The parameterized energy function which helps define the
acceptance probabilities of the Markov chain.
"""
self._energy = input_energy
self._num_bits = input_energy.num_bits
self._parameters = dict(input_energy=input_energy)
# q(i | x) in Algorithm 1
self._index_proposal_probs = tf.Variable(
[0.0] * self._num_bits, trainable=False)
self._index_proposal_dist = tfp.distributions.Categorical(
probs=self._index_proposal_probs)
self._eye_bool = tf.eye(self._num_bits, dtype=tf.bool)
def _get_index_proposal_probs(self, x):
"""Returns the value of equation 6 of the paper.
Given current state x, the locally-informed proposal evaluates f(x') - f(x)
for each x' in a Hamming ball of radius 1 around x, written H(x). Let
d(x) = [f(x') - f(x) for x' in H(x)]. The locally-informed conditional
distribution is then q(i | x) = C * e^(d(x) / T) for some temperature T,
where C is the normalization constant (T=2 is optimal under a standard
criterion).
The insight of the paper is that d(x) can be approximated using a Taylor
series, yielding d(x) ~ -(2x - 1) * df(x)/dx. This function computes the
locally-informed conditional probabilities q(i | x) using the approximation.
Note that in terms of the unnormalized log probability of the state, f(x),
used in the paper, the energy of the state is: E(x) = -f(x). So we have
d(x) ~ (2x - 1) * dE(x)/dx
Args:
x: 1D tensor which is the current state of the chain.
Returns:
The conditional probabilities q(i | x) given in equation 6.
"""
x_float = tf.cast(x, tf.float32)
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(x_float)
current_energy = tf.squeeze(self._energy(tf.expand_dims(x_float, 0)))
# f(x) = -E(x)
f_grad = -1.0 * tape.gradient(current_energy, x_float)
# Equation 3
approx_energy_diff = -(2 * tf.cast(x, tf.float32) - 1) * f_grad
return tf.nn.softmax(approx_energy_diff / 2.0)
def one_step(self, current_state, previous_kernel_results):
"""Takes one step of the TransitionKernel.
Follows algorithm 1 of
https://arxiv.org/abs/2102.04509v2
Args:
current_state: 1D `tf.Tensor` which is the current chain state.
previous_kernel_results: Required but unused argument.
Returns:
next_state: 1D `Tensor` which is the next state in the chain.
kernel_results: Empty list.
"""
del previous_kernel_results
index_proposal_probs = self._get_index_proposal_probs(current_state)
self._index_proposal_probs.assign(index_proposal_probs)
proposed_i = self._index_proposal_dist.sample()
flip_vec = self._eye_bool[proposed_i]
x_prime = tf.cast(
tf.math.logical_xor(tf.cast(current_state, tf.bool), flip_vec), tf.int8)
index_proposal_probs_prime = self._get_index_proposal_probs(x_prime)
q_ratio = index_proposal_probs_prime[proposed_i] / index_proposal_probs[
proposed_i]
energies = self._energy(tf.stack([x_prime, current_state]))
exp_f = tf.math.exp(-energies[0] + energies[1])
accept_prob = tf.math.minimum(exp_f * q_ratio, 1.0)
roll = tf.random.uniform([], dtype=tf.float32)
accept = tf.math.less_equal(roll, accept_prob)
flip_vec = tf.math.logical_and(accept, self._eye_bool[proposed_i])
next_state = tf.cast(
tf.math.logical_xor(tf.cast(current_state, tf.bool), flip_vec), tf.int8)
kernel_results = []
return next_state, kernel_results
@property
def is_calibrated(self):
"""Returns `True` if Markov chain converges to specified distribution."""
return True
def bootstrap_results(self, init_state):
"""Returns an object with the same type as returned by `one_step(...)[1]`.
Args:
init_state: 1D `tf.Tensor` which is the initial chain state.
Returns:
kernel_results: Empty list.
"""
del init_state
return []
[docs]class GibbsWithGradientsInference(EnergyInference):
"""Manages inference using a Gibbs With Gradients kernel."""
def __init__(self,
input_energy: energy.BitstringEnergy,
num_expectation_samples: int,
num_burnin_samples: int,
name: Union[None, str] = None):
"""Initializes a GibbsWithGradientsInference.
Args:
input_energy: The parameterized energy function which defines this
distribution via the equations of an energy based model. This class
assumes that all parameters of `energy` are `tf.Variable`s and that
they are all returned by `energy.variables`.
num_expectation_samples: Number of samples to draw and use for estimating
the expectation value.
num_burnin_samples: Number of samples to discard when letting the chain
equilibrate after updating the parameters of `input_energy`.
name: Optional name for the model.
"""
super().__init__(input_energy, num_expectation_samples, name=name)
self._kernel = GibbsWithGradientsKernel(input_energy)
self._chain_state = tf.Variable(
tfp.distributions.Bernoulli(
probs=[0.5] * self.energy.num_bits, dtype=tf.int8).sample(),
trainable=False)
self.num_burnin_samples = num_burnin_samples
def _ready_inference(self):
"""See base class docstring.
Runs the chain for a number of steps without saving the results, in order
to better reach equilibrium before recording samples.
"""
state = self._chain_state.read_value()
for _ in tf.range(self.num_burnin_samples):
state, _ = self._kernel.one_step(state, [])
self._chain_state.assign(state)
def _call(self, inputs, *args, **kwargs):
"""See base class docstring."""
return self.sample(inputs)
def _sample(self, num_samples: int):
"""See base class docstring.
The kernels are repeatedly called to traverse chains of samples.
"""
ta = tf.TensorArray(tf.int8, size=num_samples)
state = self._chain_state.read_value()
for i in tf.range(num_samples):
state, _ = self._kernel.one_step(state, [])
ta = ta.write(i, state)
self._chain_state.assign(state)
return ta.stack()