|
139 | 139 | "inter1= [ -.01434325, -.01460965, 0, 0, 0, -.01113493, 0, 0, 0, -.0553269, -.03238896, 0, 0, -.07062459, -.07464545, -.07032613, 0, 0, -.01408955, 0, -.00219072, 0, 0, 0, 0, 0, .07300876, .01394272, 0, 0, 0, 0, 0, 0, .05120398, 0, -.00550709, -.02062663, -.03077685, -.01688493, 0, .01149963, 0, .01149963, .01149963, 0, 0, 0, 0, 0, 0, 0, 0, 0, .01149963, .0034338, .0376236, .00733331, 0, .03832785, .03832785, -.02622275, -.02622275, -.02622275, -.01492678, 0, 0, -.02897806, -.02847666, 0, 0, -.04224754, -.04743705, -.0510477, -.031893, 0, 0, 0, -.01503116, .003101, -.00083466, .02395027, -.07952866, 0, 0, -.06586029, 0, -.0613939, -.081205, -.07540084, -.08488011, -.08488011, 0, -.07492433, -.08907269, -.09451609, 0, -.08980743, 0, -.0771635, 0, 0, -.0771635, -.08204606, 0, -.05263504, 0, -.05109092, -.04696729, 0, -.04696729, 0, -.05303248, -.05348096, 0, 0, .00584956, -.00792241, -.01719816, 0, -.01576016, 0, -.04014061, 0, 0, 0, 0, 0, .0471441, 0, .04233112, 0, .04233112, 0, 0, .0493324, .04512087, .03205975, .02913185, 0, .05324252, 0, 0, 0, 0, .05054695, 0, .14026688, .01734403, .06078221, 0, 0, 0, -.03138622, 0, .01637333, 0, 0, 0, 0, .01897239, .01591935, 0, -.0619156, 0, -.06851645, 0, -.03889525, -.05023452, -.05013452, 0, 0, -.01362136, 0, 0, -.02634164, 0, 0, 0, 0, -.00890537, -.00611669, 0, 0, 0, -.01513384, 0, -.03551984, 0, -.01978032, 0, .06706496, .10551275, 0, .03092981, .06556855, 0, 0, 0, .09362991, 0, 0, 0, 0, 0, 0, .02610553, .03546937, 0, 0, .034415, 0, 0, 0, .07546701, 0, 0, 0, 0, -.02919447, -.01016712, 0, 0, 0, 0, -.04845615, -.05010044, 0, 0, 0, 0, 0, 0, -.07666632, 0, 0, -.07226554, -.08216553, -.0777643, 0, 0, -.04727952, 0, -.06870384, -.05999847, 0, 0, 0, .02772475, .02883079, .03642944, 0, .04148949, 0, 0, 0, .04268012, .03225577, 0, -.05140995, -.05399637, 0, 0, .02432223, 0, .0490674, .0490674, .0490674, 0, 0, 0, 0, 0, 0, 0, 0, .10476315, 0, 0, 0, 0, 0, .07008056, 0, 0, .01667466, 0, .05253941, .04293926, 0, .02692172, 0, 0, .08742411, .04533176, 0, .01831875, 0, .09834951, .09952456, 0, .02945534, .038731, 0, .04435538, 0, -.02357505, 0, 0, -.02357505, .09324722, 0, 0, 0, -.03490683, 0, -.05054474, 0, -.0474724, -.04905931, 0, .02879751, 0, 0, 0, 0, 0, 0, 0, .04439012, 0, .02989959, .02989959, .05468828, .04463226, 0, 0, 0, 0, 0, .01231324, -.01399783, .04595331, .00145386, 0, .06459354, -.0007196, 0, -.07614055, -.08435525, 0, -.10299519, 0, 0, 0, -.00210284, -.00797183, 0, 0, 0, 0, -.03545086, 0, 0, 0, 0, -.061286, -.07666647, 0, -.05902354, -.07652324, -.07645561, 0, 0, 0, -.03292062, 0, 0, 0, 0, -.075417, 0, -.07922532, 0, -.08583414, -.07450142, -.08066016, 0, 0, -.06249051, 0, 0, 0, 0, -.0618688, 0, -.06524737, -.04419825, -.04489509, 0, 0, 0, -.04520512, -.04187583, 0, 0, -.03753508, 0, 0, 0, 0, 0, 0, 0, 0, .06862645, 0, 0, -.00120631, .01947345, 0, 0, .03561932, 0, .03158225, .03608047, 0, 0, 0, -.02899643],\n",
|
140 | 140 | "\n",
|
141 | 141 | "inter2= [-.78348798, -.63418788, 0, 0, 0, .11481193, 0, 0, 0, -.88128799, -1.109488, 0, 0, -.30888793, .29651192, -.36688802, 0, 0, -.59088796, 0, .50561196, 0, 0, 0, 0, 0, 2.1662121, .08891205, 0, 0, 0, 0, 0, 0, -.23918791, 0, -.9575879, -.07728811, .29641202, 1.2273121, 0, 1.5764117, 0, .72131211, 1.279212, 0, 0, 0, 0, 0, 0, 0, 0, 0, .36481193, 1.5480118, -.03078791, 1.389112, 0, .70901209, -.16668792, 1.435812, .47001198, 2.0838118, 1.1673121, 0, 0, 1.4470119, .23301201, 0, 0, -.61948794, -.41388795, .263212, .66171199, 0, 0, 0, 1.6920118, 1.334012, 1.2101121, .41591194, -.48498794, 0, 0, .09911207, 0, -.46908805, .0205119, .0535119, -.14228792, -.55708808, 0, -.54008788, -.30998799, -.10958811, 0, -.01338812, 0, -.51788801, 0, 0, .13271193, -.11208793, 0, -.54508799, 0, .16641192, .95871216, 0, 1.6281118, 0, -.49718806, -.41348812, 0, 0, -.11718794, -.57058805, -.59488791, 0, -.65658802, 0, -.52698797, 0, 0, 0, 0, 0, 1.3500118, 0, 1.665812, 0, 1.963912, 0, 0, 1.9371119, .90991193, -.39558789, .39521196, 0, -.05268808, 0, 0, 0, 0, -.12458798, 0, -.28228804, .79281193, -.26358792, 0, 0, 0, -.72828788, 0, .355912, 0, 0, 0, 0, -.43538806, -1.566388, 0, -.28388807, 0, -.69028801, 0, -.78128809, -.54648799, -.92738789, 0, 0, .61571199, 0, 0, 1.012012, 0, 0, 0, 0, .43991187, .9404121, 0, 0, 0, .61671191, 0, 2.6073117, 0, -.60438794, 0, -.18108793, -.48178813, 0, -.22628804, -.07398792, 0, 0, 0, 1.830512, 0, 0, 0, 0, 0, 0, -.36918804, 1.3247118, 0, 0, 1.163012, 0, 0, 0, 1.3241119, 0, 0, 0, 0, -.90038794, -1.250888, 0, 0, 0, 0, -1.048188, -.90138787, 0, 0, 0, 0, 0, 0, -.00878807, 0, 0, .46301201, -.22048803, -.71518797, 0, 0, -.4952881, 0, -.83718795, .57951194, 0, 0, 0, 1.054112, .61721212, 2.2717118, 0, 2.0280118, 0, 0, 0, 2.6503119, 2.3914118, 0, -.36418793, -.9259879, 0, 0, .16971211, 0, 1.9360118, 2.5344119, 2.0171118, 0, 0, 0, 0, 0, 0, 0, 0, .77211195, 0, 0, 0, 0, 0, 1.952312, 0, 0, .25901201, 0, -.5028879, .03641204, 0, .38571194, 0, 0, -.44528791, -.55918807, 0, .39721206, 0, -.34038803, -.05988808, 0, -.03518792, .045512, 0, -.039288, 0, .01431207, 0, 0, .030412, -.31918809, 0, 0, 0, -.324388, 0, -1.232188, 0, -.23678799, -.89188808, 0, -.6766879, 0, 0, 0, 0, 0, 0, 0, -.67818803, 0, 1.099412, 1.2767119, -.64068788, -.50678796, 0, 0, 0, 0, 0, .68371207, .11251191, -.17128797, .17081194, 0, -.48708794, .09591202, 0, -.20108791, -.02158805, 0, -.3012881, 0, 0, 0, -.87638801, -.54488796, 0, 0, 0, 0, -.14738794, 0, 0, 0, 0, -.75718802, -.37418792, 0, 1.0981121, 1.1441121, .47381189, 0, 0, 0, -.33498809, 0, 0, 0, 0, -.91838807, 0, -.34488794, 0, .12971191, .99381214, -.91608804, 0, 0, .98171192, 0, 0, 0, 0, -.01528808, 0, -.41458794, .25691202, .18601207, 0, 0, 0, -.35608789, .79691201, 0, 0, 1.548912, 0, 0, 0, 0, 0, 0, 0, 0, 1.1663117, 0, 0, -.009088, -.49578807, 0, 0, -.2677879, 0, -.25468799, .68631202, 0, 0, 0, -.36198804])\n",
|
142 |
| - "#fmt: off" |
| 142 | + "# fmt: off" |
143 | 143 | ]
|
144 | 144 | },
|
145 | 145 | {
|
|
149 | 149 | "outputs": [],
|
150 | 150 | "source": [
|
151 | 151 | "def load_data_cox(dta):\n",
|
152 |
| - " array = lambda x : np.array(dta[x], dtype=float)\n", |
153 |
| - " t = array('t')\n", |
154 |
| - " obs_t = array('obs_t')\n", |
155 |
| - " pscenter = array('pscenter')\n", |
156 |
| - " hhcenter = array('hhcenter')\n", |
157 |
| - " ncomact = array('ncomact')\n", |
158 |
| - " rleader = array('rleader')\n", |
159 |
| - " dleader = array('dleader')\n", |
160 |
| - " inter1 = array('inter1')\n", |
161 |
| - " inter2 = array('inter2')\n", |
162 |
| - " fail = array('FAIL')\n", |
163 |
| - " return (t, obs_t, pscenter, hhcenter, ncomact,\n", |
164 |
| - " rleader, dleader, inter1, inter2, fail)" |
| 152 | + " array = lambda x: np.array(dta[x], dtype=float)\n", |
| 153 | + " t = array(\"t\")\n", |
| 154 | + " obs_t = array(\"obs_t\")\n", |
| 155 | + " pscenter = array(\"pscenter\")\n", |
| 156 | + " hhcenter = array(\"hhcenter\")\n", |
| 157 | + " ncomact = array(\"ncomact\")\n", |
| 158 | + " rleader = array(\"rleader\")\n", |
| 159 | + " dleader = array(\"dleader\")\n", |
| 160 | + " inter1 = array(\"inter1\")\n", |
| 161 | + " inter2 = array(\"inter2\")\n", |
| 162 | + " fail = array(\"FAIL\")\n", |
| 163 | + " return (t, obs_t, pscenter, hhcenter, ncomact, rleader, dleader, inter1, inter2, fail)" |
165 | 164 | ]
|
166 | 165 | },
|
167 | 166 | {
|
|
170 | 169 | "metadata": {},
|
171 | 170 | "outputs": [],
|
172 | 171 | "source": [
|
173 |
| - "(t, obs_t, pscenter, hhcenter, ncomact, rleader,\n", |
174 |
| - " dleader, inter1, inter2, fail) = load_data_cox(dta)" |
| 172 | + "(t, obs_t, pscenter, hhcenter, ncomact, rleader, dleader, inter1, inter2, fail) = load_data_cox(dta)" |
175 | 173 | ]
|
176 | 174 | },
|
177 | 175 | {
|
|
210 | 208 | "outputs": [],
|
211 | 209 | "source": [
|
212 | 210 | "with Model() as model:\n",
|
213 |
| - " \n", |
| 211 | + "\n", |
214 | 212 | " T = len(t) - 1\n",
|
215 | 213 | " nsubj = len(obs_t)\n",
|
216 | 214 | "\n",
|
217 | 215 | " # risk set equals one if obs_t >= t\n",
|
218 | 216 | " Y = np.array([[int(obs >= time) for time in t] for obs in obs_t])\n",
|
219 | 217 | " # counting process. jump = 1 if obs_t \\in [t[j], t[j+1])\n",
|
220 |
| - " dN = np.array([[Y[i,j]*int(t[j+1] >= obs_t[i])*fail[i] for j in range(T)] for i in\n", |
221 |
| - " range(nsubj)])\n", |
| 218 | + " dN = np.array(\n", |
| 219 | + " [[Y[i, j] * int(t[j + 1] >= obs_t[i]) * fail[i] for j in range(T)] for i in range(nsubj)]\n", |
| 220 | + " )\n", |
| 221 | + "\n", |
| 222 | + " c = Gamma(\"c\", 0.0001, 0.00001)\n", |
| 223 | + " r = Gamma(\"r\", 0.001, 0.0001)\n", |
| 224 | + "\n", |
| 225 | + " dL0_star = r * np.diff(t)\n", |
222 | 226 | "\n",
|
223 |
| - " c = Gamma('c', .0001, .00001)\n", |
224 |
| - " r = Gamma('r', .001, .0001)\n", |
225 |
| - " \n", |
226 |
| - " dL0_star = r*np.diff(t)\n", |
227 |
| - " \n", |
228 | 227 | " # prior mean hazard\n",
|
229 |
| - " mu = dL0_star * c \n", |
230 |
| - " \n", |
231 |
| - " dL0 = Gamma('dL0', mu, c, shape=T)\n", |
| 228 | + " mu = dL0_star * c\n", |
| 229 | + "\n", |
| 230 | + " dL0 = Gamma(\"dL0\", mu, c, shape=T)\n", |
232 | 231 | "\n",
|
233 |
| - " beta = Normal('beta', np.zeros(7),\n", |
234 |
| - " np.ones(7)*100, shape=7)\n", |
| 232 | + " beta = Normal(\"beta\", np.zeros(7), np.ones(7) * 100, shape=7)\n", |
235 | 233 | "\n",
|
236 | 234 | " linear_model = tt.exp(tt.dot(X.T, beta))\n",
|
237 | 235 | " idt = Y[:, :-1] * tt.outer(linear_model, dL0)\n",
|
238 | 236 | "\n",
|
239 |
| - " dn_like = Poisson('dn_like', idt, observed=dN)" |
| 237 | + " dn_like = Poisson(\"dn_like\", idt, observed=dN)" |
240 | 238 | ]
|
241 | 239 | },
|
242 | 240 | {
|
|
365 | 363 | ],
|
366 | 364 | "source": [
|
367 | 365 | "with model:\n",
|
368 |
| - " trace = sample(2000, n_init=10000, init='advi_map')" |
| 366 | + " trace = sample(2000, n_init=10000, init=\"advi_map\")" |
369 | 367 | ]
|
370 | 368 | },
|
371 | 369 | {
|
|
390 | 388 | }
|
391 | 389 | ],
|
392 | 390 | "source": [
|
393 |
| - "az.plot_trace(trace, var_names=['c', 'r']);" |
| 391 | + "az.plot_trace(trace, var_names=[\"c\", \"r\"]);" |
394 | 392 | ]
|
395 | 393 | },
|
396 | 394 | {
|
|
415 | 413 | }
|
416 | 414 | ],
|
417 | 415 | "source": [
|
418 |
| - "az.plot_forest(trace, var_names=['beta']);" |
| 416 | + "az.plot_forest(trace, var_names=[\"beta\"]);" |
419 | 417 | ]
|
420 | 418 | },
|
421 | 419 | {
|
|
0 commit comments