#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#define NUM 6 /* number of rows of the input matrix X */
#define DIM 4 /* number of columns of X; the weight matrices are DIM x DIM */

/*
 * Allocates a zero-filled array of n floats.
 * Out-of-memory policy for this program: print a message and exit with
 * EXIT_FAILURE (the OS reclaims anything allocated earlier).
 */
static float *alloc_floats(size_t n)
{
    float *p = calloc(n, sizeof *p);
    if (p == NULL) {
        fprintf(stderr, "out of memory allocating %zu floats\n", n);
        exit(EXIT_FAILURE);
    }
    return p;
}

/*
 * Affine projection out = x * w^T + b, i.e.
 *   out[i][j] = sum_k w[j][k] * x[i][k] + b[j]
 * x and out are rows x dim, w is dim x dim, all row-major; b has dim entries.
 */
static void linear(int rows, int dim, const float *x, const float *w,
                   const float *b, float *out)
{
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < dim; j++) {
            float acc = 0.0f;
            for (int k = 0; k < dim; k++) {
                acc += w[j * dim + k] * x[i * dim + k];
            }
            out[i * dim + j] = acc + b[j];
        }
    }
}

/* Prints a rows x cols row-major matrix: one row per line, tab-separated,
 * followed by an empty line so consecutive matrices stay distinguishable. */
static void print_matrix(const float *m, int rows, int cols)
{
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            printf("%.1f\t", m[i * cols + j]);
        }
        putchar('\n');
    }
    putchar('\n');
}

#ifdef ENABLE_ATTENTION
/*
 * Scaled dot-product attention: c = softmax(q * k^T / sqrt(dim)) * v, with the
 * softmax taken along each row.
 * q, k, v and c are num x dim row-major; scores and alfa are num x num
 * row-major scratch buffers (alfa holds the attention weights on return).
 * Requires linking with the math library (-lm on most Unix toolchains).
 */
static void attention(int num, int dim, const float *q, const float *k,
                      const float *v, float *scores, float *alfa, float *c)
{
    const float norm = sqrtf((float)dim);

    for (int i = 0; i < num; i++) {
        float row_max = -INFINITY;
        for (int j = 0; j < num; j++) {
            float dot = 0.0f;
            for (int t = 0; t < dim; t++) {
                dot += q[i * dim + t] * k[j * dim + t];
            }
            scores[i * num + j] = dot / norm;
            if (scores[i * num + j] > row_max) {
                row_max = scores[i * num + j];
            }
        }
        /* Shift by the row maximum before exponentiating so expf cannot
         * overflow; softmax is invariant to a constant shift per row. */
        float sum = 0.0f;
        for (int j = 0; j < num; j++) {
            alfa[i * num + j] = expf(scores[i * num + j] - row_max);
            sum += alfa[i * num + j];
        }
        for (int j = 0; j < num; j++) {
            alfa[i * num + j] /= sum;
        }
    }

    for (int i = 0; i < num; i++) {
        for (int j = 0; j < dim; j++) {
            float acc = 0.0f;
            for (int t = 0; t < num; t++) {
                acc += alfa[i * num + t] * v[t * dim + j];
            }
            c[i * dim + j] = acc;
        }
    }
}
#endif /* ENABLE_ATTENTION */

/*
 * Builds a fixed NUM x DIM input X plus DIM x DIM weights and DIM-wide biases,
 * projects X to value (vn), key (kn) and query (qn) matrices with
 * x * W^T + b, and prints the three NUM x DIM results.
 *
 * When compiled with -DENABLE_ATTENTION it additionally runs the attention
 * stage, prints its elapsed processor time in seconds (clock()), and, if
 * exactly one command-line argument is given, prints the context matrix c.
 */
int main(int argc, char *argv[])
{
    (void)argv; /* only the argument count is consulted */

    const int num2 = NUM;
    const int dim2 = DIM;

    float *X  = alloc_floats((size_t)num2 * dim2);
    float *Wk = alloc_floats((size_t)dim2 * dim2);
    float *Wq = alloc_floats((size_t)dim2 * dim2);
    float *Wv = alloc_floats((size_t)dim2 * dim2);
    float *Bk = alloc_floats((size_t)dim2);
    float *Bq = alloc_floats((size_t)dim2);
    float *Bv = alloc_floats((size_t)dim2);
    float *kn = alloc_floats((size_t)num2 * dim2);
    float *vn = alloc_floats((size_t)num2 * dim2);
    float *qn = alloc_floats((size_t)num2 * dim2);

    // ------------------------------------------------------------
    // Initialization
    for (int i = 0; i < num2; i++) {
        for (int j = 0; j < dim2; j++) {
            X[i * dim2 + j] = (float)(i + 6 * j);
        }
    }
    for (int row = 0; row < dim2; row++) {
        for (int col = 0; col < dim2; col++) {
            Wk[row * dim2 + col] = -0.2 + 0.1 * row; /* constant along each row */
            Wq[row * dim2 + col] = -0.2 + 0.1 * col; /* constant along each column */
            Wv[row * dim2 + col] = (row == col) ? 1.0f : 0.0f; /* identity */
        }
    }
    for (int j = 0; j < dim2; j++) {
        Bv[j] = 0.0f;
        Bq[j] = 0.1f;
        Bk[j] = -1.0f;
    }

    // ------------------------------------------------------------
    // Projections: each result is num2 x dim2, row-major.
    linear(num2, dim2, X, Wv, Bv, vn);
    linear(num2, dim2, X, Wk, Bk, kn);
    linear(num2, dim2, X, Wq, Bq, qn);

    print_matrix(vn, num2, dim2);
    print_matrix(kn, num2, dim2);
    print_matrix(qn, num2, dim2);

#ifdef ENABLE_ATTENTION
    float *aqkn = alloc_floats((size_t)num2 * num2);
    float *alfa = alloc_floats((size_t)num2 * num2);
    float *c    = alloc_floats((size_t)num2 * dim2);

    /* NOTE(review): the original timer source for t0/t1 was not visible;
     * clock() measures processor time, not wall-clock time. */
    clock_t t0 = clock();
    attention(num2, dim2, qn, kn, vn, aqkn, alfa, c);
    clock_t t1 = clock();
    printf("%f\n", (double)(t1 - t0) / CLOCKS_PER_SEC);

    if (argc == 2) {
        print_matrix(c, num2, dim2);
    }

    free(aqkn);
    free(alfa);
    free(c);
#else
    (void)argc;
#endif

    free(X);
    free(Wk);
    free(Wq);
    free(Wv);
    free(Bk);
    free(Bq);
    free(Bv);
    free(kn);
    free(vn);
    free(qn);
    return 0;
}