ころがる狸

ころがる狸のデータ解析ブログ

【状態空間モデル】PyStanとpykalmanでダウ平均株価予測

こんにちは。ゴールデンウィーク3日目です。緊急事態宣言が5月末まで延長しそうです。家に籠って勉強なりゲームなりをしています。

今日は、状態空間モデルを取り上げます。状態空間モデルでは、実際の観測値とその背後にある真の状態を分けて考えます。真の状態は時間とともに変化しますが、私たち観測者にはその状態が見えません。観測者が手にすることができるのは観測値のみで、これに基づいて真の状態を推定します。もっとも素朴なモデルでは、真の状態における1つの時間ステップでの変化は微小であると想定したり、観測されるのは真の状態にノイズがのったものであるとする仮定を置いたりします。このような状態空間モデルのイメージ図として以下のような図が用いられることが多いです。真の状態が時々刻々と推移しており、私たちが観測する値はそこから派生したものであると見なします。

f:id:Dajiro:20200504142842p:plain
状態空間モデルのイメージ
状態空間モデルを用いた予測の方法には何通りか方法がありますが、ここではカルマンフィルタを用いた方法とMCMC(マルコフ連鎖モンテカルロ法)を用いたベイズ統計モデリングを試してみたいと思います。参考書としては以下の2冊を参照しました。
こちらはpythonベースで書かれており、pykalmanによるカルマンフィルタの実装が記載されており参考になります。こちらはR言語で記載されていますが、ベイズ統計モデリングでの悩みどころか解析のテクニックが随所で触れられているのでRとStanを知らなくてもお薦めです。現に私はRでのプログラミングは未経験です。
StanとRでベイズ統計モデリング (Wonderful R)

StanとRでベイズ統計モデリング (Wonderful R)

目次はこちらです。

解析対象データ

Yahoo!からダウ平均株価のcsvファイルを取得して解析しました。日ごと、週ごと、月ごとでダウンロードできますがここでは月ごとの1985年5月から2020年5月までの時系列データとしました。

カルマンフィルタ

カルマンフィルタは状態を推定する仕組みです。観測値を用いて、状態の推定値の精度を逐次的に補正していきます。ここで、状態が線形的に変化しノイズが正規分布に従うと仮定した線形ガウス型モデルの一般的な数式を記載しておきます。

真の状態(システム):\it{x}_{t} = F_{t}\it{x}_{t-1} + G_{t}v_{t},    v_{t} \sim \rm{N}(0, Q_{t})
観測値:\it{y}_{t} = H_{t}\it{x}_{t} + w_{t},     w_{t} \sim \rm{N}(0, R_{t})

これは真の状態(システム)の推移と、真の状態からの観測値の生成をモデリングしたものです。上式のFGHをユーザーが設定する必要があります。以下の解析では、観測値は真の状態にノイズがのったものとして観測され、真の状態は時刻の差分の変化が微小(システムノイズに等しい)というモデルを作ってみたいと思います。数式で書くと以下のようになります。

真の状態(システム):x_{\rm{t}} - x_{\rm{t-1}} = x_{\rm{t-1}} - x_{\rm{t-2}} + v_{\rm{t}}
観測値:y_{\rm{t}} = x_{\rm{t}} + w_{\rm{t}}

このコードは以下のようになります。

import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
import numpy as np
from pykalman import KalmanFilter

def FGHset(n_dim_trend, n_dim_obs=1, Q_sigma2=10):
    """
    KalmanFilterに用いる行列の初期化を行う。
    
    Parameters
    ----------
    n_dim_trend : int
        トレンドの次数(階差の次数)
    n_dim_obs : int
        観測値の次元数
    Q_sigma2 : float
        システムノイズの分散
    Returns
    -------
    n_dim_state: int
        状態の次元数
    F : ndarray
        推移行列
    H : ndarray
        観測行列
    Q : ndarray
        システムモデルの分散共分散行列
    """
    
    n_dim_Q = 1
    n_dim_state = n_dim_trend
    
    G = np.zeros((n_dim_state, n_dim_Q)) #システムノイズを定義
    F = np.zeros((n_dim_state, n_dim_state)) #推移行列
    H = np.zeros((n_dim_obs, n_dim_state)) #観測行列
    Q = np.eye(n_dim_Q) * Q_sigma2 #システムノイズの分散
    
    G[0, 0] = 1
    H[0, 0] = 1
    
    if n_dim_trend == 1:
        F[0, 0] = 1
    elif n_dim_trend == 2:
        F[0, 0] = 2
        F[0, 1] = -1
        F[1, 0] = 1
    elif n_dim_trend == 3:
        F[0, 0] = 3
        F[0, 1] = -3
        F[0, 2] = 1
        F[1, 0] = 1
        F[2, 1] = 1
    
    Q = G.dot(Q).dot(G.T)
    
    return n_dim_state, F, H, Q

n_dim_obs = 1 #観測値の次元数
n_dim_trend = 2 #トレンドの次数(階差の次数)

n_dim_state, F, H, Q = FGHset(n_dim_trend, n_dim_obs)

initial_state_mean = np.zeros(n_dim_state)
initial_state_covariance = np.ones((n_dim_state, n_dim_state))

kf = KalmanFilter(
    n_dim_obs=n_dim_obs,
    n_dim_state=n_dim_state,
    initial_state_mean=initial_state_mean,
    initial_state_covariance=initial_state_covariance,
    transition_matrices=F,
    observation_matrices=H,
    observation_covariance=1.0,
    transition_covariance=Q)

df = pd.read_csv('data/ダウ平均株価(1985-2020).csv')

n_train = 350
y = df.Close
train, test = y.values[:n_train], y.values[n_train:]

smoothed_state_means, smoothed_state_covs = kf.smooth(train)

pred_o_smoothed = smoothed_state_means.dot(H.T)

plt.plot(y, label="観測値")

pred_y = np.empty(len(test))

current_state = smoothed_state_means[-1]
current_cov = smoothed_state_covs[-1]
for i in range(len(test)):
    current_state, current_cov = kf.filter_update(current_state,
                                                  current_cov,
                                                  observation=None)
    pred_y[i] = kf.observation_matrices.dot(current_state)

plt.plot(np.hstack([pred_o_smoothed.flatten(), pred_y]), '--', label="予測値")
plt.legend(fontsize = 14)
計算結果

上記のコードによって、ダウ平均株価の訓練用データにフィッティングを行い、トレンドを抽出して予測ができるようになりました。その結果が下の図で、おおむねテスト用データに対する上昇傾向はうまく捉えられていることが分かります。これが周期的に振動したデータであれば、季節ごとの変化を組み込むなどの処理が考えられますが、株価には恐らく周期的なモデルは適用できなそうなので、とりあえずここで解析をとめておきましょう。

f:id:Dajiro:20200504161429p:plain
カルマンフィルタによるダウ平均株価のトレンド予測。破線の直線部分が予測に対応。

MCMCによるベイズ統計モデリング

ベイズ統計モデリングとは、ある事象が観測されたときにそのモデルに含まれるパラメータの妥当な分布を求める手法です。事象が観測される前にはどのようなパラメータが最適化であるか推測できないため、人が設定した分布(無情報事前分布と呼びます)から適当にサンプリングするしかありません。しかし観測データがあれば、その事象の発生確率を上げるような最適なパラメータの分布(事後分布)を計算できるようになります。
ベイズの定理から、この事後分布p(\theta|{\rm Y})は以下の比例関係に従うことが分かります。

p(\theta|{\rm Y}) \propto p({\rm Y}|\theta)

この式の右辺を求めるのに、MCMC(マルコフ連鎖モンテカルロ法)が使われ事後分布が得られます。もちろんこれはパラメータ推定なので、これ自体で予測を行っているわけではありません。ベイズ推定による予測値は以下の数式で表現できます。

p_{pred} (y|Y) = \int p(y|\theta)p(\theta|Y)d\theta

この積分計算の中に含まれているp(\theta|Y)をMCMCサンプルとして計算し、既知の関数であるp(y|\theta)と掛け合わせ積分を和で置き換えることによって予測値を計算できるようになります。これは予測値の平均と信頼区間を計算できるため、予測値が信頼に足るかどうかを判断することができます。ここで紹介するPyStanはこのMCMCを計算し、予測値の平均と分散を計算してくれる便利なライブラリです。

モデルはカルマンフィルタの場合と同様ですが、下のPyStanの記述の可読性を上げるため、モデルを記載しておきます。
真の状態(システム):\mu_{t} \sim {\rm Normal} (2\mu_{t-1} - \mu_{t-2}, \sigma_{t})
観測値:Y_{t} \sim {\rm Normal}(\mu_{t}, \sigma_{Y})

以下がPyStanによるコードです。PyStanのインストールはAnaconda環境で開発する場合conda install -c conda-forge pystanで簡単にできます。

import pystan
import pandas as pd
import matplotlib.pyplot as plt

#データの読み込み
df = pd.read_csv('data/ダウ平均株価(1985-2020).csv')

#Stanで処理するためのデータを辞書形式で定義
n_train = 350
n_pred = len(df.Close) - n_train
dat = {'T': n_train, 'T_pred': n_pred, 'Y': df.Close.values[:n_train]}

# ベイズ統計モデルの定義
model = """
    data {
        int T;
        int T_pred;
        vector[T] Y;
    }
    parameters { 
        vector[T] mu;
        real<lower=0> s_mu;
        real<lower=0> s_Y;
    }
    model { 
        mu[3:T] ~ normal(2*mu[2:(T-1)] - mu[1:(T-2)], s_mu);
        Y ~ normal(mu, s_Y);
    }
    
    generated quantities {
        vector[T+T_pred] mu_all;
        vector[T_pred] y_pred;
        mu_all[1:T] = mu;
        for (t in 1:T_pred) {
            mu_all[T+t] = normal_rng(2*mu_all[T+t-1] - mu_all[T+t-2], s_mu);
            y_pred[t] = normal_rng(mu_all[T+t], s_Y);
        }
    }
    
"""
#モデルのコンパイル
stm = pystan.StanModel(model_code=model)

n_itr = 5000
n_warmup = 100
chains = 4

# サンプリングの実行
fit = stm.sampling(data=dat, iter=n_itr, chains=chains, warmup=n_warmup, algorithm="NUTS",  control=dict(max_treedepth=20), verbose=False)

コードの核心部分はこのようになります。これを実行するとfitに各パラメータの収束の良さを示すRhatが記載され、これが1.1未満であればMCMCが収束したと考えて良いようです。

f:id:Dajiro:20200504212056p:plain
MCMCの結果を出力。Rhatの値に注意。
また上記の計算ではMCMCサンプラーとしてNUTSを採用しましたが、他にはハミルトニアンモンテカルロ法なども使えるようです。max_treedepthはデフォルトでは10ですが、それでは計算が収束しなかったため、上記では20としています。チェーン数は4が推奨で、イテレーションは当初500としており、計算が収束するのを確認してから5000まで引き上げ、計算の精度を上げました(つらつらと書きましたが、詳細はこの記事の冒頭で紹介したベイズ統計本(アヒル本)を参照してください)。ちなみにこの計算は80分ほどかかりました(並列化なし)。カルマンフィルタの計算などに比べると桁違いに時間がかかります。

計算結果

結果を見てみましょう。まずはMCMCがきちんと計算できているか確認します。これは訓練データの最初の点の推定結果で、5000イテレーションを4チェーンで、ウォームアップに各チェーンで100ずつ無視しているので計19600点分のデータが含まれています。各データはノイジーで無相関に見え、収束の指標であるRhatも全てのデータ点で1.0という成績です。MCMCは良く収束しているようです。

f:id:Dajiro:20200504213543p:plain
MCMCの結果。4つのチェーンの結果をマージしている。ウォームアップは無視。
待ちに待った予測値と、誤差(標準偏差)はこちらになります。訓練データには良く適合できており誤差も小さいのですが、訓練に使ってないデータの予測精度となると誤差が爆発的に増加していっていますね。しかし、平均値を見ると上昇傾向にあるので株価推移のトレンドは把握できていると言えそうです。またトレンドの傾きはカルマンフィルタを用いた結果よりも急峻で、より現実の株価に近い結果となりました。まぁ、かとって株価推移に潜む隠れたパターンを抽出できた、とするのは早計ですが。
f:id:Dajiro:20200504214331p:plain
MCMCを用いて計算されたダウ平均株価の予測値と標準偏差。
プロットに使ったコードはこちらです。

import japanize_matplotlib
la  = fit.extract(permuted=True)

std = [la['mu_all'][:, i].std() for i in range(422)]
upper = [mean_all[i] + std[i] for i in range(422)]
lower = [mean_all[i] - std[i] for i in range(422)]

plt.plot(df.Close.values, label = '観測値')
plt.plot(mean_all, "--", label = "予想値")
plt.fill_between(range(422), lower, upper, facecolor='y',alpha=0.5)
plt.ylim(0, 30000)
plt.legend(fontsize=14)
plt.savefig('MCMC.png', bbox_inches='tight', dpi=300)

まとめ

カルマンフィルタ、MCMCいずれも適当なモデルを設定することで予測値の計算までもっていくことができました。今回は教科書的なモデルを構築しましたが、設計者次第で様々なモデルを作ることができ、それによって結果も変わってくることでしょう。そのどれが真実か、ということは間違った問いでいずれも人間による恣意的なモデルなのですが、時々刻々と更新される現実の観測値とモデルの予測値を比較して、それを表現するようにモデルを絶えず更新していくという努力が必要でしょう。
(余談)ちなみに、難しい話はおいといてPyStan面白いですね。イテレーションの数だけモデルの平均値なり標準偏差を計算できるので、重ねてプロットすると以下のような図が得られるわけです。結果の解析が楽しい!

f:id:Dajiro:20200504220139p:plain
テストデータに対する19600イテレーションの結果を重ねてプロット。
f:id:Dajiro:20200504220239p:plain
訓練データに対する19600イテレーションの結果を重ねてプロット。