mxfusion.inference.variational

Members

class mxfusion.inference.variational.StochasticVariationalInference(num_samples, model, posterior, observed)

Bases: mxfusion.inference.inference_alg.InferenceAlgorithm

Implements the Stochastic Variational Inference (SVI) algorithm.

Parameters:
  • num_samples (int) – the number of samples used in estimating the variational lower bound
  • model (Model) – the definition of the probabilistic model
  • posterior (Posterior) – the definition of the variational posterior of the probabilistic model
  • observed ([Variable]) – a list of observed variables

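The num_samples parameter sets how many draws are used in a Monte Carlo estimate of the variational lower bound. The sketch below is a minimal, library-independent illustration of such an estimate in plain Python, not MXFusion's implementation. The toy model (z ~ N(0, 1), x | z ~ N(z, 1)), the Gaussian variational posterior and the helper names log_normal_pdf and estimate_elbo are all assumptions chosen for illustration.

```python
import math
import random


def log_normal_pdf(x, mean, std):
    """Log density of a univariate Gaussian N(mean, std**2) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def estimate_elbo(x, q_mean, q_std, num_samples, rng=random):
    """Monte Carlo estimate of the variational lower bound (ELBO).

    Assumed toy model (illustration only): z ~ N(0, 1), x | z ~ N(z, 1).
    Assumed variational posterior: q(z) = N(q_mean, q_std**2).
    """
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(q_mean, q_std)  # draw z from the variational posterior q
        log_joint = log_normal_pdf(z, 0.0, 1.0) + log_normal_pdf(x, z, 1.0)  # log p(z) + log p(x | z)
        log_q = log_normal_pdf(z, q_mean, q_std)  # log q(z)
        total += log_joint - log_q
    return total / num_samples  # average over the num_samples draws
```

For x = 1.0 the exact posterior of this toy model is N(0.5, 0.5), so estimate_elbo(1.0, 0.5, math.sqrt(0.5), num_samples=10) returns log N(1.0; 0, 2) up to floating-point rounding for any draw, because every term of the average equals the log marginal likelihood. Any other choice of q_mean and q_std yields an estimate whose expectation lies strictly below that value, and a larger num_samples reduces the variance of the estimate.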
posterior

Return the variational posterior.

compute(F, data, parameters, constants)

Compute the outcome of the inference algorithm.

Parameters:
  • F (Python module) – the execution context (mxnet.ndarray or mxnet.symbol)
  • data ({Variable: mxnet.ndarray.ndarray.NDArray or mxnet.symbol.symbol.Symbol}) – the data variables for inference
  • parameters ({Variable: mxnet.ndarray.ndarray.NDArray or mxnet.symbol.symbol.Symbol}) – the parameters for inference
  • constants – the constants for inference
Returns:

the outcome of the inference algorithm

Return type:

mxnet.ndarray.ndarray.NDArray or mxnet.symbol.symbol.Symbol
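compute receives the execution context F as an argument so that a single method body can run either eagerly (mxnet.ndarray) or symbolically (mxnet.symbol). The sketch below illustrates that convention with a hypothetical gaussian_log_likelihood function; it is not the body of compute. To stay runnable without MXNet, it passes a small stand-in namespace built from Python's math module as F, and it keys data, parameters and constants by strings rather than by Variable objects.

```python
import math
import types


def gaussian_log_likelihood(F, data, parameters, constants):
    """Hypothetical computation written against an execution context F.

    Only F.exp is looked up on F; the remaining arithmetic uses Python
    operators, which MXNet NDArrays and Symbols also overload.
    """
    x = data["x"]
    mean = parameters["mean"]
    log_std = parameters["log_std"]
    std = F.exp(log_std)  # dispatched to whichever backend F provides
    return -0.5 * constants["log_2pi"] - log_std - (x - mean) ** 2 / (2 * std ** 2)


# Stand-in execution context for plain floats (assumption: not part of MXNet or MXFusion).
scalar_F = types.SimpleNamespace(exp=math.exp)

value = gaussian_log_likelihood(
    scalar_F,
    data={"x": 1.0},
    parameters={"mean": 0.0, "log_std": 0.0},
    constants={"log_2pi": math.log(2 * math.pi)},
)
print(value)
```

With F = mxnet.ndarray and NDArray inputs, the same function would return an NDArray; with F = mxnet.symbol and Symbol inputs, it would build a computation graph instead. Passing F explicitly is what lets one definition serve both modes.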