fork download
  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <time.h>
  4. #include <math.h>
  5.  
  6. #define NUM 6
  7. #define DIM 4
  8.  
  /* Fills X (num x dim, row-major) with the fixed test input X[i][j] = i + 6*j. */
  static void init_input(float *X, int num, int dim)
  {
      for (int i = 0; i < num; i++) {
          for (int j = 0; j < dim; j++) {
              X[i * dim + j] = (float)(i + 6 * j);
          }
      }
  }

  /*
   * Fills the dim x dim weight matrices, row-major, where row r holds the
   * weights that produce output feature r:
   *   Wk[r][c] = -0.2 + 0.1*r   (constant along each row)
   *   Wq[r][c] = -0.2 + 0.1*c   (constant down each column)
   *   Wv       = identity
   */
  static void init_weights(float *Wk, float *Wq, float *Wv, int dim)
  {
      for (int r = 0; r < dim; r++) {
          for (int c = 0; c < dim; c++) {
              Wk[r * dim + c] = (float)(-0.2 + 0.1 * r);
              Wq[r * dim + c] = (float)(-0.2 + 0.1 * c);
              Wv[r * dim + c] = (r == c) ? 1.0f : 0.0f;
          }
      }
  }

  /* Fills the dim-element bias vectors: Bk = -1, Bq = 0.1, Bv = 0. */
  static void init_biases(float *Bk, float *Bq, float *Bv, int dim)
  {
      for (int j = 0; j < dim; j++) {
          Bk[j] = -1.0f;
          Bq[j] = 0.1f;
          Bv[j] = 0.0f;
      }
  }

  /*
   * Affine projection of every input row:
   *   out[i][j] = sum_k w[j][k] * in[i][k] + b[j]
   * in and out are rows x dim, w is dim x dim, b has dim entries; all
   * row-major. out is fully overwritten.
   */
  static void project(const float *in, const float *w, const float *b,
                      float *out, int rows, int dim)
  {
      for (int i = 0; i < rows; i++) {
          for (int j = 0; j < dim; j++) {
              float acc = 0.0f;
              for (int k = 0; k < dim; k++) {
                  acc += w[j * dim + k] * in[i * dim + k];
              }
              out[i * dim + j] = acc + b[j];
          }
      }
  }

  /* Prints a rows x cols row-major matrix ("%.1f\t" per entry, one row per
     line) followed by a blank line. */
  static void print_matrix(const float *m, int rows, int cols)
  {
      for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
              printf("%.1f\t", m[i * cols + j]);
          }
          printf("\n");
      }
      printf("\n");
  }

  /*
   * Builds a fixed num2 x dim2 input X, projects it into values (vn), keys (kn)
   * and queries (qn) with dim2 x dim2 weights plus per-feature biases, and
   * prints vn, kn and qn, each as num2 rows of dim2 entries.
   *
   * Returns EXIT_SUCCESS, or EXIT_FAILURE if any allocation fails.
   */
  int main(int argc, char *argv[])
  {
      (void)argc; /* only referenced by the disabled attention stage below */
      (void)argv;

      const int num2 = 6; /* rows of X */
      const int dim2 = 4; /* entries per row of X */

      const size_t nd = (size_t)num2 * dim2;
      const size_t dd = (size_t)dim2 * dim2;
      const size_t nn = (size_t)num2 * num2;

      float *X = calloc(nd, sizeof *X);
      float *Wk = calloc(dd, sizeof *Wk);
      float *Wq = calloc(dd, sizeof *Wq);
      float *Wv = calloc(dd, sizeof *Wv);
      float *Bk = calloc((size_t)dim2, sizeof *Bk);
      float *Bq = calloc((size_t)dim2, sizeof *Bq);
      float *Bv = calloc((size_t)dim2, sizeof *Bv);

      float *kn = calloc(nd, sizeof *kn);
      float *vn = calloc(nd, sizeof *vn);
      float *qn = calloc(nd, sizeof *qn);
      /* Buffers for the disabled attention stage; it accumulates into aqkn
         and c with +=, so they must start zeroed (calloc). */
      float *aqkn = calloc(nn, sizeof *aqkn);
      float *alfa = calloc(nn, sizeof *alfa);
      float *c = calloc(nd, sizeof *c);

      int status = EXIT_SUCCESS;
      if (!X || !Wk || !Wq || !Wv || !Bk || !Bq || !Bv ||
          !kn || !vn || !qn || !aqkn || !alfa || !c) {
          fputs("allocation failed\n", stderr);
          status = EXIT_FAILURE;
          goto cleanup;
      }

      init_input(X, num2, dim2);
      init_weights(Wk, Wq, Wv, dim2);
      init_biases(Bk, Bq, Bv, dim2);

      /* One helper for all three projections keeps their indexing identical:
         every result is num2 x dim2 row-major, which is also the layout the
         attention stage reads (vn[k*dim2+j]). */
      project(X, Wv, Bv, vn, num2, dim2);
      project(X, Wk, Bk, kn, num2, dim2);
      project(X, Wq, Bq, qn, num2, dim2);

      print_matrix(vn, num2, dim2);
      print_matrix(kn, num2, dim2);
      print_matrix(qn, num2, dim2);

      /*
       * Attention stage -- disabled, kept for reference.
       * NOTE(review): t0/t1 are not defined anywhere, sumaqkn would also need
       * to be freed, and sqrt/exp require <math.h>.
       *
       *   float *sumaqkn = calloc(num2, sizeof(float));
       *   for (int i = 0; i < num2; ++i) {
       *       for (int j = 0; j < num2; ++j) {
       *           for (int k = 0; k < dim2; ++k)
       *               aqkn[i*num2+j] += qn[i*dim2+k] * kn[j*dim2+k];
       *           aqkn[i*num2+j] = aqkn[i*num2+j] / sqrt(dim2);
       *           sumaqkn[i] += exp(aqkn[i*num2+j]);
       *       }
       *   }
       *   for (int i = 0; i < num2; i++)
       *       for (int j = 0; j < num2; j++)
       *           alfa[i*num2+j] = exp(aqkn[i*num2+j]) / sumaqkn[i];
       *   for (int i = 0; i < num2; i++)
       *       for (int j = 0; j < dim2; j++)
       *           for (int k = 0; k < num2; k++)
       *               c[i*dim2+j] += alfa[i*num2+k] * vn[k*dim2+j];
       *   printf("%f\n", t1-t0);
       *   if (argc == 2)
       *       print c as num2 rows of dim2 entries ("%.1f\t" each)
       */

  cleanup:
      free(X);
      free(Wk);
      free(Wq);
      free(Wv);
      free(Bk);
      free(Bq);
      free(Bv);
      free(kn);
      free(vn);
      free(qn);
      free(aqkn);
      free(alfa);
      free(c);
      return status;
  }
  161.  
Success #stdin #stdout 0s 5304KB
stdin
Standard input is empty
stdout
0.0	13.0	3.0	16.0	0.0	0.0	
6.0	19.0	9.0	22.0	0.0	0.0	
12.0	2.0	15.0	5.0	0.0	0.0	
18.0	8.0	21.0	11.0	0.0	0.0	

-8.2	-4.6	-1.0	2.6	
-9.0	-5.0	-1.0	3.0	
-9.8	-5.4	-1.0	3.4	
-10.6	-5.8	-1.0	3.8	
-11.4	-6.2	-1.0	4.2	
-12.2	-6.6	-1.0	4.6	

1.3	1.3	-34.7	-40.7	
-46.9	-52.9	1.1	1.1	
0.9	0.9	0.9	0.9	
0.7	0.7	0.7	0.7	
0.5	0.5	0.5	0.5	
0.3	0.3	0.3	0.3