PyMC3 failing to broadcast correct dimensions for inference

  bayesian, matrix, pymc3, python, theano

I am trying to extend the ideas of item response theory to multiple response categories. Consider a marketing survey that asks customers, "What's the deciding factor in whether or not you purchase product X?", where the possible answers are {0: price, 1: durability, 2: ease-of-use}.

Here is some synthetic data (rows are customers, columns are products, and each cell is the response class).

import numpy as np
import pymc3 as pm
import theano.tensor as tt
import arviz as az

responses = np.array([
    [0, 1, 2, 1, 0],
    [1, 1, 1, 1, 1],
    [0, 0, 2, 2, 1],
    [1, 1, 2, 2, 1],
    [1, 1, 0, 0, 0]
])

students = 5
questions = 5
categories = 3

with pm.Model() as model:
    z_student = pm.Normal("z_student", mu=0, sigma=1, shape=(students,categories))
    z_question = pm.Normal("z_question",mu=0, sigma=1, shape=(categories,questions))
    
    # Transformed parameter
    theta = pm.Deterministic("theta", tt.nnet.softmax(z_student - z_question))
     
    # Likelihood
    kij = pm.Categorical("kij", p=theta, observed=responses)
    trace = pm.sample(chains=4)

az.plot_trace(trace, var_names=["z_student", "z_question"], compact=False);

This code produces the following error: ValueError: Input dimension mis-match. (input[0].shape[0] = 5, input[1].shape[0] = 3).
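
For reference, the mismatch is reproducible outside PyMC3 with plain NumPy (a minimal sketch that only mimics the shapes of z_student and z_question):

import numpy as np

z_student = np.random.normal(size=(5, 3))   # (students, categories)
z_question = np.random.normal(size=(3, 5))  # (categories, questions)

# (5, 3) - (3, 5) is not broadcastable, so this raises a dimension-mismatch
# error analogous to Theano's "Input dimension mis-match"
z_student - z_question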

However, when I change the theta line to theta = pm.Deterministic("theta", tt.nnet.softmax(z_student - z_question.transpose())), the sampler no longer fails immediately; rather, it samples the wrong quantities.
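
To make the resulting shape concrete, here is a plain NumPy sketch of the transposed expression (softmax_rows is a hypothetical stand-in for tt.nnet.softmax on a 2D input):

import numpy as np

def softmax_rows(x):
    # row-wise softmax, mirroring tt.nnet.softmax on a 2D matrix
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

z_student = np.random.normal(size=(5, 3))   # (students, categories)
z_question = np.random.normal(size=(3, 5))  # (categories, questions)

theta = softmax_rows(z_student - z_question.T)
print(theta.shape)  # (5, 3): only 15 thetas, and student i is only ever paired with question i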

az.summary(trace)

mean    sd  hdi_3%  hdi_97% mcse_mean   mcse_sd ess_mean    ess_sd  ess_bulk    ess_tail    r_hat
z_student[0,0]  0.150   0.893   -1.620  1.752   0.012   0.013   5789.0  2327.0  5771.0  2991.0  1.0
z_student[0,1]  0.393   0.879   -1.319  1.980   0.012   0.012   5150.0  2610.0  5153.0  3195.0  1.0
z_student[0,2]  -0.591  0.915   -2.254  1.108   0.011   0.012   6408.0  2737.0  6415.0  2830.0  1.0
z_student[1,0]  -0.064  0.860   -1.676  1.538   0.011   0.014   5748.0  1942.0  5747.0  2850.0  1.0
z_student[1,1]  0.602   0.864   -0.982  2.185   0.012   0.011   4921.0  3028.0  4920.0  3269.0  1.0
z_student[1,2]  -0.548  0.906   -2.218  1.137   0.012   0.012   6076.0  2870.0  6083.0  3410.0  1.0
z_student[2,0]  -0.166  0.907   -1.974  1.450   0.013   0.014   4681.0  2121.0  4692.0  3108.0  1.0
z_student[2,1]  -0.188  0.875   -1.776  1.472   0.011   0.014   5923.0  2073.0  5945.0  3333.0  1.0
z_student[2,2]  0.344   0.865   -1.288  1.951   0.012   0.012   4828.0  2750.0  4822.0  3039.0  1.0
z_student[3,0]  -0.212  0.892   -1.980  1.395   0.011   0.013   6019.0  2504.0  5996.0  3391.0  1.0
z_student[3,1]  0.097   0.876   -1.573  1.713   0.012   0.013   5304.0  2252.0  5332.0  2971.0  1.0
z_student[3,2]  0.096   0.851   -1.583  1.645   0.011   0.012   5554.0  2678.0  5543.0  3288.0  1.0
z_student[4,0]  0.160   0.881   -1.367  1.947   0.012   0.013   5421.0  2189.0  5413.0  2927.0  1.0
z_student[4,1]  0.414   0.863   -1.255  2.026   0.012   0.012   4900.0  2548.0  4897.0  3248.0  1.0
z_student[4,2]  -0.558  0.901   -2.266  1.130   0.011   0.012   6551.0  2728.0  6582.0  3142.0  1.0
z_question[0,0] -0.179  0.883   -1.795  1.488   0.011   0.015   6317.0  1769.0  6315.0  3389.0  1.0
z_question[0,1] 0.107   0.886   -1.511  1.807   0.012   0.013   5236.0  2431.0  5209.0  3503.0  1.0
z_question[0,2] 0.164   0.878   -1.450  1.834   0.012   0.013   5131.0  2248.0  5106.0  3102.0  1.0
z_question[0,3] 0.186   0.904   -1.450  1.882   0.011   0.014   6228.0  2175.0  6219.0  3335.0  1.0
z_question[0,4] -0.187  0.877   -1.790  1.508   0.011   0.014   5819.0  2089.0  5834.0  3198.0  1.0
z_question[1,0] -0.389  0.849   -1.948  1.219   0.012   0.012   4726.0  2494.0  4713.0  3146.0  1.0
z_question[1,1] -0.600  0.858   -2.249  0.946   0.012   0.011   5093.0  3247.0  5116.0  3312.0  1.0
z_question[1,2] 0.179   0.868   -1.520  1.763   0.012   0.012   5204.0  2514.0  5201.0  3418.0  1.0
z_question[1,3] -0.103  0.862   -1.683  1.561   0.013   0.013   4608.0  2212.0  4615.0  3163.0  1.0
z_question[1,4] -0.381  0.866   -2.047  1.147   0.011   0.012   6181.0  2735.0  6188.0  3038.0  1.0
z_question[2,0] 0.565   0.908   -1.125  2.337   0.012   0.012   6022.0  2879.0  6045.0  3173.0  1.0
z_question[2,1] 0.536   0.923   -1.192  2.241   0.012   0.013   6041.0  2476.0  6046.0  3059.0  1.0
z_question[2,2] -0.325  0.856   -1.918  1.289   0.012   0.012   5429.0  2741.0  5418.0  3004.0  1.0
z_question[2,3] -0.107  0.881   -1.953  1.363   0.012   0.012   5834.0  2545.0  5841.0  3332.0  1.0
z_question[2,4] 0.576   0.910   -1.202  2.253   0.011   0.013   6385.0  2606.0  6371.0  2905.0  1.0
theta[0,0]  0.360   0.173   0.072   0.685   0.003   0.002   4309.0  3774.0  4256.0  2846.0  1.0
theta[0,1]  0.528   0.182   0.208   0.857   0.003   0.002   4949.0  4563.0  4908.0  3050.0  1.0
theta[0,2]  0.113   0.104   0.001   0.304   0.001   0.001   6095.0  4045.0  7146.0  2780.0  1.0
theta[1,0]  0.216   0.144   0.007   0.477   0.002   0.002   6149.0  4576.0  6493.0  3116.0  1.0
theta[1,1]  0.678   0.168   0.381   0.962   0.002   0.002   5954.0  5954.0  6180.0  3320.0  1.0
theta[1,2]  0.107   0.100   0.000   0.294   0.001   0.001   6321.0  3863.0  7623.0  3252.0  1.0
theta[2,0]  0.234   0.150   0.010   0.509   0.002   0.002   6154.0  4352.0  6684.0  3252.0  1.0
theta[2,1]  0.230   0.152   0.005   0.506   0.002   0.001   6885.0  5424.0  6459.0  2923.0  1.0
theta[2,2]  0.536   0.186   0.194   0.858   0.002   0.002   5595.0  5250.0  5622.0  2805.0  1.0
theta[3,0]  0.239   0.157   0.007   0.526   0.002   0.002   5843.0  4627.0  5789.0  2853.0  1.0
theta[3,1]  0.381   0.178   0.065   0.703   0.003   0.002   4927.0  4377.0  5009.0  3315.0  1.0
theta[3,2]  0.380   0.174   0.069   0.692   0.003   0.002   4653.0  4176.0  4624.0  2562.0  1.0
theta[4,0]  0.361   0.175   0.057   0.668   0.002   0.002   5185.0  4637.0  5269.0  2985.0  1.0
theta[4,1]  0.527   0.184   0.186   0.852   0.003   0.002   4614.0  4445.0  4668.0  2497.0  1.0
theta[4,2]  0.111   0.100   0.002   0.303   0.001   0.001   6159.0  3978.0  7520.0  3473.0  1.0

Of note, look at the theta values that were learned. Their names run theta[0,0]…theta[0,2],…theta[4,2]. So, taking the first one as an example, what PyMC3 has learned is the strength of the relation between (z_student[0] - z_question[0]) and class/response 0.

This is not the effect I want to accomplish. I want to learn a 3D tensor accounting for every possible {student, question, category} combination: there should be 75 thetas (5 × 5 × 3), not 15, where theta[0,0,0] refers to the learned value for {student_0, question_0, response_0}. My code is currently not accomplishing this.
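
For clarity, this is the broadcasting I am after, sketched in plain NumPy (softmax_last is a hypothetical helper applying softmax over the trailing category axis):

import numpy as np

def softmax_last(x):
    # softmax over the trailing (category) axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

z_student = np.random.normal(size=(5, 3))   # (students, categories)
z_question = np.random.normal(size=(3, 5))  # (categories, questions)

# (students, 1, categories) - (1, questions, categories) -> (students, questions, categories)
diff = z_student[:, None, :] - z_question.T[None, :, :]
theta = softmax_last(diff)
print(theta.shape)  # (5, 5, 3): 75 thetas, one per {student, question, category}

Presumably the same broadcasting can be expressed with Theano tensor operations inside the model, but I have not found a formulation that the Categorical likelihood accepts.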

Any ideas?

