def s(k, K):
    """Return a non-dimensional coordinate in (-1, +1) for grid index k = 0..K-1."""
    return 2 * (0.5 + k) / K - 1


def get_dist(i, j, K):
    """Return the wrap-around distance between indices i and j on a ring of K points."""
    sep = abs(i - j)
    return sep if sep <= K / 2 else K - sep


def GCM(X0, F, dt, nt, param=(0,)):
    """Integrate the single-scale L96 tendency with a polynomial correction.

    Steps forward-Euler style: ``X += dt * (L96_eq1_xdot(X, F) - polyval(param, X))``.

    Args:
        X0: initial state (1-D array; copied, not modified).
        F: forcing passed through to ``L96_eq1_xdot``.
        dt: time step.
        nt: number of steps.
        param: polynomial coefficients for ``np.polyval`` (highest power first).
            The default is an immutable tuple so it cannot be shared/mutated
            across calls.

    Returns:
        (hist, time): ``hist`` has shape (nt+1, len(X0)) with ``hist[0] == X0``;
        ``time`` holds ``dt * n`` for n = 0..nt.
    """
    # Local name avoids shadowing the imported ``time`` module inside the function.
    model_time = dt * np.arange(nt + 1)
    hist = np.full((nt + 1, len(X0)), np.nan)
    X = X0.copy()
    hist[0] = X

    for n in range(nt):
        X = X + dt * (L96_eq1_xdot(X, F) - np.polyval(param, X))
        hist[n + 1] = X
    return hist, model_time


# Generate observation operator, assuming linearity and model space observations
def ObsOp(K, l_obs, t_obs, i_t):
    """Build the (nobs, K) selection matrix for the observations taken at time i_t.

    Row r has a single 1 in the column of the r-th location in
    ``l_obs[t_obs == i_t]``; ``nobs`` is the last dimension of ``l_obs``.
    """
    nobs = l_obs.shape[-1]
    H = np.zeros((nobs, K))
    H[np.arange(nobs), l_obs[t_obs == i_t]] = 1
    return H


# localize covariance matrix based on the Gaspari-Cohn function
def cov_loc(B, loc=0):
    """Localize covariance B with Gaspari-Cohn weights on a periodic domain.

    Args:
        B: covariance matrix; the number of rows is used as the ring size.
        loc: localization radius handed to ``gaspari_cohn``.

    Returns:
        (B * W, W) where W is the floating-point weight matrix.
    """
    M, N = B.shape
    rows, cols = np.ix_(np.arange(M), np.arange(N))
    sep = np.abs(rows - cols)
    dist = np.where(sep <= M / 2, sep, M - sep)  # same rule as get_dist
    # Weights are computed directly in floating point. Wrapping a scalar
    # function with np.vectorize would infer an *integer* output type from the
    # first element (distance 0 -> 1) and truncate every other weight to 0.
    W = gaspari_cohn(dist, loc)
    return B * W, W


def gaspari_cohn(distance, radius):
    """Evaluate the Gaspari-Cohn weight for ``distance`` given ``radius``.

    Accepts a scalar or an array for ``distance``. The weight is 1 at zero
    distance, follows the two polynomial branches for distance/radius in
    (0, 1] and (1, 2], and is 0 beyond that. A zero radius gives weight 0
    everywhere except at zero distance.

    Returns:
        A Python float for scalar input, otherwise a float ndarray.
    """
    d = np.asarray(distance, dtype=float)
    if radius == 0:
        weight = np.zeros(d.shape)
    else:
        ratio = d / radius
        near = (-ratio**5 / 4 + ratio**4 / 2 + 5 * ratio**3 / 8
                - 5 * ratio**2 / 3 + 1)
        # ratio == 0 divides by zero in the outer branch; those entries are
        # never selected below, so silence the warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            far = (ratio**5 / 12 - ratio**4 / 2 + 5 * ratio**3 / 8
                   + 5 * ratio**2 / 3 - 5 * ratio + 4 - 2 / 3 / ratio)
        weight = np.where(ratio <= 1, near, np.where(ratio <= 2, far, 0.0))
    weight = np.where(d == 0, 1.0, weight)
    return weight.item() if weight.ndim == 0 else weight


def find_obs(loc, obs, t_obs, l_obs, period):
    """Collect the observations of grid point ``loc`` within a time window.

    Considers observation rows whose time (first column of ``t_obs``) lies in
    ``(period[0], period[1]]``. Returns one value per such row: the observation
    at ``loc`` if that location was observed, NaN otherwise.
    """
    in_window = (t_obs[:, 0] > period[0]) & (t_obs[:, 0] <= period[1])
    rows = np.where(in_window)[0]
    obs_period = np.full(rows.shape, np.nan)
    for out_idx, row in enumerate(rows):
        hit = l_obs[row] == loc
        if np.any(hit):
            # Take the scalar explicitly; assigning a 1-element array to a
            # scalar slot is deprecated in NumPy.
            obs_period[out_idx] = obs[row][hit][0]
    return obs_period


def running_ave(X, N):
    """Return the N-point circular running mean of X along axis 0.

    The window covers shifts -N/2 .. N/2-1 for even N and
    -(N-1)/2 .. (N-1)/2 for odd N; ``np.roll`` makes it wrap around.
    """
    first, stop = -int(N // 2), int((N + 1) // 2)
    total = np.zeros(X.shape)
    for shift in range(first, stop):
        total = total + np.roll(X, shift, axis=0)
    return total / N
# Set up the "truth" 2-scale L96 model and generate initial conditions from a short spinup
M_truth = L96(config['K'], config['J'], F=config['F_truth'], dt=config['dt'])
M_truth.set_state(rng.standard_normal((config['K'])), 0*M_truth.j)
X_spinup, Y_spinup, t_spinup = M_truth.run(config['si'], config['si']*config['ns_spinup'])
X_init = X_spinup[-1, :]
Y_init = Y_spinup[-1, :]

# Run L96 to generate the "truth"
M_truth.set_state(X_init, Y_init)

# Give F a "seasonal cycle" in the truth model
ann_period = 2000
mon_period = 100
mon_per_ann = ann_period/mon_period
X_step, Y_step, t_step = M_truth.run(config['si'], config['si']*mon_period)
# Collect the monthly segments and join them once at the end; concatenating
# onto the full history inside the loop copies the whole record every month.
X_parts, Y_parts, t_parts = [X_step], [Y_step], [t_step]
for i in range(1, int(config['ns']/mon_period)):
    M_truth.set_state(X_parts[-1][-1, ...], Y_parts[-1][-1, ...])
    M_truth.set_param(F=config['F_truth']+2*np.sin(2*np.pi*i/mon_per_ann))
    X_step, Y_step, t_step = M_truth.run(config['si'], config['si']*mon_period)
    # drop the first record of each segment: it repeats the previous segment's last state
    X_parts.append(X_step[1:, ...])
    Y_parts.append(Y_step[1:, ...])
    t_parts.append(t_parts[-1][-1] + t_step[1:])
X_truth = np.concatenate(X_parts)
Y_truth = np.concatenate(Y_parts)
t_truth = np.concatenate(t_parts)

# X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*config['ns'])

# # generate climatological background covariance for 2-scale L96 model
# B_clim = np.cov(X_truth.T)
# np.save('B_clim_L96.npy', B_clim)

plt.figure(figsize=(12, 10))
plt.subplot(221)  # Snapshot of X[k]
plt.plot(M_truth.k, X_truth[-1, :], label='X')
plt.plot(M_truth.j/M_truth.J, Y_truth[-1, :], label='Y')
plt.legend()
plt.xlabel('k')
# raw string: '\D' is not a valid escape sequence in a normal string literal
plt.title(r'$X,Y$ @ $t=N\Delta t$')
plt.plot(M_truth.k, X_truth[0, :], 'k:')
plt.plot(M_truth.j/M_truth.J, Y_truth[0, :], 'k:')
plt.subplot(222)  # Sample time-series X[0](t), Y[0](t)
plt.plot(t_truth, X_truth[:, 0], label='X')
plt.plot(t_truth, Y_truth[:, 0], label='Y')
plt.xlabel('t')
plt.title('$X[0,t]$, $Y[0,t]$')
plt.subplot(223)  # Full model history of X
plt.contourf(M_truth.k, t_truth, X_truth)
plt.colorbar(orientation='horizontal')
plt.xlabel('k')
plt.ylabel('t')
plt.title('$X(k,t)$')
plt.subplot(224)  # Full model history of Y
plt.contourf(M_truth.j/M_truth.J, t_truth, Y_truth)
plt.colorbar(orientation='horizontal')
plt.xlabel('k')
plt.ylabel('t')
plt.title('$Y(k,t)$')

# # generate climatological background covariance for 1-scale L96 model
# M_1s = L96s(config['K'], F=config['F_truth'], dt=config['dt'], method=RK4)
# M_1s.set_state(X_init)
# X1_truth,t1_truth = M_1s.run(config['si']*config['ns'])
# B_clim_1s = np.cov(X1_truth.T)
# np.save('B_clim_1s.npy', B_clim_1s)

# Sample the "truth" to generate observations at certain times (t_obs) and locations (l_obs)
n_obs = int(config['K']*config['obs_density'])
t_obs = np.tile(np.arange(config['obs_freq'], config['ns_da']+config['obs_freq'], config['obs_freq']),
                [n_obs, 1]).T
l_obs = np.zeros(t_obs.shape, dtype='int')
for i in range(l_obs.shape[0]):
    l_obs[i, :] = rng.choice(config['K'], n_obs, replace=False)
X_obs = X_truth[t_obs, l_obs]+config['obs_sigma']*rng.standard_normal(l_obs.shape)
# print(X_obs.shape)

# Calculate observation covariance matrix, assuming independent observations
R = config['obs_sigma']**2*np.eye(n_obs)

# plt.figure(figsize=[6,4])
# plt.scatter(t_obs,X_obs)

import DA_methods
import importlib
importlib.reload(DA_methods)

t0 = time.time()

n_state = config['K']
n_ens = config['nens']
n_param = len(config['GCM_param'])
da_freq = config['DA_freq']
n_cycles = int(config['ns_da']/da_freq)

# load pre-calculated climatological background covariance matrix from a long simulation
B_clim = np.load('B_clim_L96s.npy')
B_loc, W_clim = cov_loc(B_clim, loc=config['B_loc'])

# set up array to store DA increments
X_inc = np.zeros((n_cycles, n_state, n_ens))
if config['DA'] == '3DVar':
    X_inc = np.squeeze(X_inc)
t_DA = np.zeros(n_cycles)

# initialize ensemble with perturbations
# The full ensemble history is allocated once; appending every cycle with
# np.concatenate re-copies the whole (growing) history each time.
i_t = 0
ensX = np.empty((n_cycles*da_freq+1, n_state, n_ens))
ensX[0] = X_init[:, None]+rng.standard_normal((n_state, n_ens))
X_post = ensX[0, ...]

if config['param_DA']:
    mean_param = np.zeros((n_cycles, n_param))
    spread_param = np.zeros((n_cycles, n_param))
    param_scale = config['param_sd']
    # parameters are not localized: weight 1 for every parameter row/column
    W = np.ones((n_state+n_param, n_state+n_param))
    W[0:n_state, 0:n_state] = W_clim
else:
    W = W_clim
    param_scale = np.zeros(config['GCM_param'].shape)

ens_param = np.zeros((n_param, n_ens))
for i in range(n_param):
    ens_param[i, :] = config['GCM_param'][i]+rng.normal(scale=param_scale[i], size=n_ens)

# DA cycles
for cycle in range(n_cycles):

    # set up array to store model forecast for each DA cycle
    ensX_fcst = np.zeros((da_freq+1, n_state, n_ens))

    for n in range(n_ens):
        ensX_fcst[..., n] = GCM(X_post[0:n_state, n], config['F_fcst'], config['dt'],
                                da_freq, ens_param[:, n])[0]
    i_t = i_t+da_freq

    X_prior = ensX_fcst[-1, ...]  # get prior from model integration
    obs_now = t_obs == i_t        # mask of the observations valid at this analysis time

    # call DA
    t_DA[cycle] = t_truth[i_t]
    if config['DA'] == 'EnKF':
        H = ObsOp(n_state, l_obs, t_obs, i_t)
        # augment state vector with parameters when doing parameter estimation
        if config['param_DA']:
            H = np.concatenate((H, np.zeros((H.shape[0], n_param))), axis=-1)
            X_prior = np.concatenate((X_prior, ens_param))
        B_ens = np.cov(X_prior)
        B_ens_loc = B_ens*W
        X_post = DA_methods.EnKF(X_prior, X_obs[obs_now], H, R, B_ens_loc)
        X_post[0:n_state, :] = DA_methods.ens_inflate(X_post[0:n_state, :], X_prior[0:n_state, :],
                                                      config['inflate_opt'], config['inflate_factor'])
        # Inflate the parameter rows only when they exist. Without this guard
        # the trailing rows are ordinary state variables and would be inflated
        # a second time.
        if config['param_DA']:
            X_post[n_state:, :] = DA_methods.ens_inflate(X_post[n_state:, :],
                                                         X_prior[n_state:, :],
                                                         config['param_inflate'],
                                                         config['param_inf_factor'])
            ens_param = X_post[n_state:, :]
    elif config['DA'] == 'HyEnKF':
        H = ObsOp(n_state, l_obs, t_obs, i_t)
        B_ens = np.cov(X_prior)*(1-config['hybrid_factor'])+B_clim*config['hybrid_factor']
        B_ens_loc = B_ens*W
        X_post = DA_methods.EnKF(X_prior, X_obs[obs_now], H, R, B_ens_loc)
        X_post = DA_methods.ens_inflate(X_post, X_prior, config['inflate_opt'], config['inflate_factor'])
    elif config['DA'] == '3DVar':
        X_prior = np.squeeze(X_prior)
        H = ObsOp(n_state, l_obs, t_obs, i_t)
        X_post = DA_methods.Lin3dvar(X_prior, X_obs[obs_now], H, R, B_loc, 3)
        X_post = X_post[:, None]
    elif config['DA'] == 'Replace':
        # Copy first so the prior stays intact for the increment diagnostics,
        # and broadcast each observation across the ensemble dimension.
        X_post = X_prior.copy()
        X_post[l_obs[obs_now], :] = X_obs[obs_now][:, None]
    elif config['DA'] == 'None':
        X_post = X_prior

    if not config['DA'] == 'None':
        X_inc[cycle, ...] = np.squeeze(X_post[0:n_state, ...])-X_prior[0:n_state, ...]  # get current increments
        # get posterior info about the estimated parameters
        # NOTE(review): nesting under the branch above was reconstructed from
        # whitespace-collapsed source; confirm against the original notebook.
        if config['param_DA']:
            mean_param[cycle, :] = ens_param.mean(axis=-1)
            spread_param[cycle, :] = ens_param.std(axis=-1)

    # reset initial conditions for next DA cycle
    ensX_fcst[-1, :, :] = X_post[0:n_state, :]
    ensX[i_t-da_freq+1:i_t+1] = ensX_fcst[1:, ...]

if config['DA'] == '3DVar':
    X_inc = X_inc[..., None]

print(X_inc.shape)
t1 = time.time()
print(t1-t0)

# post processing and visualization
meanX = np.mean(ensX, axis=-1)
t_da_window = t_truth[0:(config['ns_da']+1)]
X_err = meanX-X_truth[0:(config['ns_da']+1), :]
clim = np.max(np.abs(X_err))

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
ch = axes[0, 0].contourf(M_truth.k, t_da_window, X_err,
                         cmap='bwr', vmin=-clim, vmax=clim, extend='neither')
plt.colorbar(ch, ax=axes[0, 0], orientation='horizontal')
axes[0, 0].set_xlabel('s')
axes[0, 0].set_ylabel('t')
axes[0, 0].set_title('X - X_truth')
axes[0, 1].plot(t_da_window, np.sqrt((X_err**2).mean(axis=-1)), label='RMSE')
axes[0, 1].plot(t_da_window, np.mean(np.std(ensX, axis=-1), axis=-1), label='Spread')
axes[0, 1].legend()
axes[0, 1].set_xlabel('t')
axes[0, 1].set_title('RMSE (X - X_truth)')
axes[0, 1].grid(which='both', linestyle='--')

# axes[0,2].plot(M_truth.k, np.sqrt(((meanX-X_truth[0:(config['ns_da']+1),:])**2).mean(axis=0)),label='RMSE');
# X_inc_ave=X_inc/config['DA_freq']/config['si']
# axes[0,2].plot(M_truth.k, X_inc_ave.mean(axis=(0,-1)),label='Inc');
# axes[0,2].plot(M_truth.k, running_ave(X_inc_ave.mean(axis=(0,-1)),7),label='Inc Ave');
# axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(config['F_fcst']-config['F_truth']),label='F_bias');
# axes[0,2].plot(M_truth.k, np.ones(M_truth.k.shape)*(X_inc/config['DA_freq']/config['si']).mean(),'k:',label='Ave Inc');
# axes[0,2].legend()
# axes[0,2].set_xlabel('s'); axes[0,2].set_title('Increments');
# axes[0,2].grid(which='both',linestyle='--')

# Mean increment per unit time, composited over the "annual cycle"
X_inc_ave = (X_inc/config['DA_freq']/config['si']).mean(axis=(1, 2)).\
    reshape(int(config['ns_da']/ann_period), int(ann_period/config['DA_freq'])).mean(axis=0)
axes[0, 2].plot(np.arange(ann_period/config['DA_freq']), X_inc_ave, label='Inc')
axes[0, 2].plot(np.arange(ann_period/config['DA_freq']), running_ave(X_inc_ave, 10), label='Inc Ave')
axes[0, 2].plot(np.arange(0, ann_period/config['DA_freq'], mon_period/config['DA_freq']),
                -2*np.sin(2*np.pi*np.arange(mon_per_ann)/mon_per_ann), label='F_bias')
axes[0, 2].legend()
axes[0, 2].set_xlabel('"annual cycle"')
axes[0, 2].set_title('Increments')
axes[0, 2].grid(which='both', linestyle='--')

plot_start, plot_end = 1000, 1400
plot_start_DA, plot_end_DA = int(plot_start/config['DA_freq']), int(plot_end/config['DA_freq'])
plot_x = 0
axes[1, 0].plot(t_truth[plot_start:plot_end], X_truth[plot_start:plot_end, plot_x], label='truth')
axes[1, 0].plot(t_truth[plot_start:plot_end], meanX[plot_start:plot_end, plot_x], label='forecast')
axes[1, 0].scatter(t_DA[plot_start_DA:plot_end_DA],
                   find_obs(plot_x, X_obs, t_obs, l_obs, [plot_start, plot_end]), label='obs')
axes[1, 0].grid(which='both', linestyle='--')
axes[1, 0].legend()

if config['param_DA']:
    for i, c in zip(np.arange(len(config['GCM_param']), 0, -1), ['r', 'b', 'g', 'k']):
        axes[1, 1].plot(t_DA, running_ave(mean_param[:, i-1], 100), c+'-',
                        label='C{} {:3f}'.format(i-1, mean_param[int(len(t_DA)/2):None, i-1].mean()))
        axes[1, 1].plot(t_DA, running_ave(mean_param[:, i-1]+spread_param[:, i-1], 100), c+':',
                        label='SD {:3f}'.format(spread_param[int(len(t_DA)/2):None, i-1].mean()))
        axes[1, 1].plot(t_DA, running_ave(mean_param[:, i-1]-spread_param[:, i-1], 100), c+':')
    axes[1, 1].legend()
    axes[1, 1].grid(which='both', linestyle='--')

axes[1, 2].text(0.1, 0.1, 'GCM param={}\nRMSE={:3f}\nSpread={:3f}\nDA={},{},{}\nDA_freq={}\nB_loc={}\ninflation={},{}\nobs_density={}\nobs_sigma={}\nobs_freq={}'.\
                format(config['GCM_param'], np.sqrt((X_err**2).mean()),
                       np.mean(np.std(ensX, axis=-1)), config['DA'],
                       config['nens'], config['hybrid_factor'], config['DA_freq'], config['B_loc'],
                       config['inflate_opt'], config['inflate_factor'], config['obs_density'], config['obs_sigma'],
                       config['obs_freq']),
                fontsize=15)

# exp_number=np.random.randint(1,10000)
# f = open('config_{0}.txt'.format(exp_number),"w")
# f.write( str(config) )
# f.close()
# plt.savefig('fig_{0}.jpg'.format(exp_number))

# B_clim1=np.load('B_clim_L96s.npy')
# B_clim2=np.load('B_clim_L96.npy')
# B_corr1=np.zeros(B_clim1.shape)
# B_corr2=np.zeros(B_clim2.shape)
# for i in range(40):
#     for j in range(40):
#         B_corr1[i,j]=B_clim1[i,j]/np.sqrt(B_clim1[i,i]*B_clim1[j,j])
#         B_corr2[i,j]=B_clim2[i,j]/np.sqrt(B_clim2[i,i]*B_clim2[j,j])

# print(B_corr)
# plt.figure(figsize=(16,6))
# plt.subplot(121)
# plt.contourf(B_corr1,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))
# plt.colorbar()
# plt.title('Background correlation matrix 1-scale L96')
# plt.subplot(122)
# plt.contourf(B_corr2,cmap='bwr',extend='both',levels=np.linspace(-0.95,0.95,20))
# plt.colorbar()
# plt.title('Background correlation matrix 2-scale L96')
""" Data assimilation methods
Adapted from PyDA project: https://github.com/Shady-Ahmed/PyDA
Reference: https://www.mdpi.com/2311-5521/5/4/225
"""
import numpy as np

try:
    from numba import jit
except ImportError:  # numba only accelerates these routines; run as plain NumPy without it
    def jit(func):
        """No-op replacement for ``numba.jit`` used as a bare decorator."""
        return func


@jit
def Lin3dvar(ub,w,H,R,B,opt):
    """Solve the linear 3DVar analysis problem.

    Args:
        ub: background (prior) state.
        w: observation vector.
        H: linear observation operator (matrix).
        R: observation-error covariance matrix.
        B: background-error covariance matrix.
        opt: formulation selector -- 1 model-space, 2 model-space incremental,
            3 observation-space incremental.

    Returns:
        The analysis state ``ua``.

    Raises:
        ValueError: if ``opt`` is not 1, 2 or 3 (previously this fell through
            to ``return ua`` with ``ua`` never assigned).
    """
    # The solution of the 3DVAR problem in the linear case requires
    # the solution of a linear system of equations.
    # Here we utilize the built-in numpy function to do this.
    # Other schemes can be used, instead.

    if opt == 1: #model-space approach
        Bi = np.linalg.inv(B)
        Ri = np.linalg.inv(R)
        A = Bi + (H.T)@Ri@H
        b = Bi@ub + (H.T)@Ri@w
        ua = np.linalg.solve(A,b) #solve a linear system

    elif opt == 2: #model-space incremental approach
        Bi = np.linalg.inv(B)
        Ri = np.linalg.inv(R)
        A = Bi + (H.T)@Ri@H
        b = (H.T)@Ri@(w-H@ub)
        ua = ub + np.linalg.solve(A,b) #solve a linear system

    elif opt == 3: #observation-space incremental approach
        A = R + H@B@(H.T)
        b = (w-H@ub)
        ua = ub + B@(H.T)@np.linalg.solve(A,b) #solve a linear system

    else:
        raise ValueError("opt must be 1, 2 or 3")

    return ua


@jit
def ens_inflate(posterior,prior,opt,factor):
    """Inflate the spread of a posterior ensemble.

    Args:
        posterior: analysis ensemble, shape (n, N) -- same shape as ``prior``.
        prior: forecast ensemble, shape (n, N) with N ensemble members.
        opt: ``"multiplicative"`` scales the posterior perturbations by
            ``1 + factor``; ``"relaxation"`` blends posterior and prior
            perturbations with weights ``1 - factor`` and ``factor`` around
            the posterior mean.
        factor: inflation strength.

    Returns:
        The inflated ensemble, shape (n, N).

    Raises:
        ValueError: for any other ``opt`` (previously an all-zero ensemble was
            returned silently).
    """
    n,N=prior.shape
    # sum/repeat/reshape is used for the row means (rather than keepdims) as
    # in the original code, presumably to stay within what numba can compile.
    mean_post=(posterior.sum(axis=-1)/N).repeat(N).reshape(n,N)
    if opt == "multiplicative":
        inflated=posterior+factor*(posterior-mean_post)

    elif opt == "relaxation":
        mean_prior=(prior.sum(axis=-1)/N).repeat(N).reshape(n,N)
        inflated=mean_post+(1-factor)*(posterior-mean_post)+factor*(prior-mean_prior)

    else:
        raise ValueError("opt must be 'multiplicative' or 'relaxation'")

    return inflated


@jit
def EnKF(prior,obs,H,R,B):
    """Analysis step of the stochastic (perturbed-observation) ensemble Kalman filter.

    Args:
        prior: forecast ensemble, shape (n, N).
        obs: observation vector, length m.
        H: linear observation operator, shape (m, n).
        R: observation-error covariance, shape (m, m).
        B: background-error covariance, shape (n, n).

    Returns:
        The analysis ensemble, shape (n, N).

    Raises:
        AssertionError: if the argument shapes are inconsistent.
    """
    # The analysis step for the (stochastic) ensemble Kalman filter
    # with virtual observations

    n,N = prior.shape # n is the state dimension and N is the size of ensemble
    m = obs.shape[0] # m is the size of measurement vector

    mR = R.shape[0]
    nB = B.shape[0]
    mH, nH = H.shape
    assert m==mR, "obseravtion and obs_cov_matrix have incompatible size"
    assert nB==n, "state and state_cov_matrix have incompatible size"
    assert m==mH, "obseravtion and obs operator have incompatible size"
    assert n==nH, "state and obs operator have incompatible size"

    # compute Kalman gain
    D = H@B@H.T + R
    K = B @ H.T @ np.linalg.inv(D)

    # Factor R = S S^T so that S @ z (z ~ N(0, I)) has covariance R.
    # For diagonal R the element-wise square root is exactly that factor; for
    # correlated errors it is not (and is NaN for negative covariances), so
    # use the Cholesky factor there.
    if np.count_nonzero(R - np.diag(np.diag(R))) == 0:
        R_sqrt = np.sqrt(R)
    else:
        R_sqrt = np.linalg.cholesky(R)

    # perturb observations
    obs_ens=obs.repeat(N).reshape(m,N)+R_sqrt@np.random.standard_normal((m,N))
    # compute analysis ensemble
    posterior = prior + K @ (obs_ens-H@prior)

    return posterior
a/L96_model.py b/L96_model.py index db24d25e..957cefef 100644 --- a/L96_model.py +++ b/L96_model.py @@ -23,8 +23,9 @@ def L96_eq1_xdot(X, F): K = len(X) Xdot = np.zeros(K) - for k in range(K): - Xdot[k] = ( X[(k+1)%K] - X[k-2] ) * X[k-1] - X[k] + F + Xdot = np.roll(X,1) * ( np.roll(X,-1) - np.roll(X,2) ) - X + F +# for k in range(K): +# Xdot[k] = ( X[(k+1)%K] - X[k-2] ) * X[k-1] - X[k] + F return Xdot @jit @@ -54,9 +55,9 @@ def L96_2t_xdot_ydot(X, Y, F, h, b, c): Ysummed = Y.reshape((K,J)).sum(axis=-1) - #Xdot = np.roll(X,1) * ( np.roll(X,-1) - np.roll(X,2) ) - X + F - hcb * Ysummed - for k in range(K): - Xdot[k] = ( X[(k+1)%K] - X[k-2] ) * X[k-1] - X[k] + F - hcb * Ysummed[k] + Xdot = np.roll(X,1) * ( np.roll(X,-1) - np.roll(X,2) ) - X + F - hcb * Ysummed +# for k in range(K): +# Xdot[k] = ( X[(k+1)%K] - X[k-2] ) * X[k-1] - X[k] + F - hcb * Ysummed[k] #for j in range(JK): # k = j//J