logistic-regression - 为什么在这个逻辑回归示例中 Pymc3 ADVI 比 M

我知道 ADVI/MCMC 之间的数学差异,但我试图了解使用其中一个的实际含义。我正在对以这种方式创建的数据运行一个非常简单的逻辑回归示例:

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np

def logistic(x, b, noise=None):
    L = x.T.dot(b)
    if noise is not None:
        L = L+noise
    return 1/(1+np.exp(-L))

x1 = np.linspace(-10., 10, 10000)
x2 = np.linspace(0., 20, 10000)
bias = np.ones(len(x1))
X = np.vstack([x1,x2,bias]) # Add intercept
B =  [-10., 2., 1.] # Sigmoid params for X + intercept

# Noisy mean
pnoisy = logistic(X, B, noise=np.random.normal(loc=0., scale=0., size=len(x1)))
# dichotomize pnoisy -- sample 0/1 with probability pnoisy
y = np.random.binomial(1., pnoisy)

我像这样运行 ADVI:
with pm.Model() as model: 
    # Define priors
    intercept = pm.Normal('Intercept', 0, sd=10)
    x1_coef = pm.Normal('x1', 0, sd=10)
    x2_coef = pm.Normal('x2', 0, sd=10)

    # Define likelihood
    likelihood = pm.Bernoulli('y',                  
           pm.math.sigmoid(intercept+x1_coef*X[0]+x2_coef*X[1]),
                          observed=y)
    approx = pm.fit(90000, method='advi')

不幸的是,无论我增加多少采样,ADVI 似乎都无法恢复我定义的原始 beta [-10., 2., 1.],而 MCMC 工作正常(如下图)



谢谢您的帮助!

最佳答案

这是个有趣的问题!默认'advi'在 PyMC3 中是平均场变分推理,它在捕获相关性方面做得并不好。事实证明,您建立的模型有一个有趣的相关结构,可以通过以下方式看出:

import arviz as az

az.plot_pair(trace, figsize=(5, 5))



PyMC3 有一个内置的收敛检查器 - 运行优化太长或太短都会导致有趣的结果:
from pymc3.variational.callbacks import CheckParametersConvergence

with model:
    fit = pm.fit(100_000, method='advi', callbacks=[CheckParametersConvergence()])

draws = fit.sample(2_000)

这对我来说在大约 60,000 次迭代后停止。现在我们可以检查相关性并看到,正如预期的那样,ADVI 拟合轴对齐高斯:
az.plot_pair(draws, figsize=(5, 5))



最后,我们可以比较 NUTS 和(平均场)ADVI 的拟合:
az.plot_forest([draws, trace])



请注意,ADVI 低估了方差,但与每个参数的均值相当接近。此外,您可以设置 method='fullrank_advi'以更好地捕捉您看到的相关性。

(注意: arviz 即将成为 PyMC3 的绘图库)

关于logistic-regression - 为什么在这个逻辑回归示例中 Pymc3 ADVI 比 MCMC 差?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52558826/

相关文章:

elm - HTML 可选属性

janusgraph - 如何在 Janusgraph 中获取索引键列表?

maven - 将 Artifact 安装到特定的远程 Maven 存储库

prolog - '/1' 在 Prolog 中代表什么?

python-3.x - pyspark中的异常值检测

ios - 如何在 SwiftUI 中全屏显示 View ?

shell - 画一棵圣诞树

deployment - 如何在 Kubernetes 中相同部署的两个 Pod 中使环境变量不同?

github - 如何从 GitHub 评论为我的拉取请求触发 Travis 重建?

scala - 如何在 Scala 中实现 Python 的 issuperset()