import java.io.* ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.Comparator ;
import java.util.HashMap ;
import java.util.StringTokenizer ;
// NOTE(review): a whitespace-mangled duplicate of class Main used to occupy this spot. It had lost
// the main() signature, several local declarations, the Reader class header and its method headers,
// and it made this compilation unit invalid (two declarations of public class Main, and import
// declarations appearing after a type declaration). It held no code that is not also in the
// complete class Main declared after the import block below, so it was removed.
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.StringTokenizer;
 
 
/**
 * Offline path queries on a tree: for each query (u, v), prints how many distinct node weights
 * appear on the path from u to v, both endpoints included.
 *
 * <p>Input (whitespace separated): {@code n m}, then the weights of nodes 1..n, then n-1 edges
 * {@code u v}, then m queries {@code u v}. Each path is mapped onto a range of the Euler tour
 * (entry/exit positions) and all ranges are answered with Mo's algorithm; answers are printed one
 * per line in input order.
 */
public class Main{
	public static final int mod = 1000000007;
	/** Adjacency lists indexed by node id; index 0 is never a real node. */
	public static ArrayList<Integer> adj[];
	/*
	 * w     - compressed weight id of each node (1..number of distinct weights)
	 * a     - node found at each Euler-tour position (positions 1..2n)
	 * st/et - tour position at which a node is entered / left
	 * par   - par[x][i] is the 2^i-th ancestor of x, or -1 if there is none
	 * depth - edge distance from node 1
	 * occ   - how many of a node's two tour positions are inside the current Mo window
	 * cnt   - per weight id, how many window nodes with occ == 1 carry that weight
	 * time  - Euler-tour clock; sum - number of weight ids whose cnt is at least 1
	 */
	public static int w[],a[],st[],et[],par[][],depth[],occ[],cnt[],time,sum;
	/** Capacity of the per-node arrays; buildTree grows it when n does not fit. */
	public static int N = 40011;
	/** Minimum number of binary-lifting levels allocated in par. */
	private static final int LOG = 16;

	/** Reads the tree and the queries from standard input and prints one answer per line. */
	public static void main(String[] args)throws IOException {
		Reader in = new Reader();

		int n = in.nextInt();
		int m = in.nextInt();

		buildTree(in, n);
		Query[] q = readQueries(in, m);
		int[] ans = answerQueries(q, n);

		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < m; i++)
			sb.append(ans[i]).append('\n');
		System.out.println(sb);
	}

	/**
	 * Allocates the per-node arrays, reads the n weights (compressed to ids 1..k) and the n-1 edges,
	 * runs the Euler tour from node 1 and fills the binary-lifting table.
	 */
	@SuppressWarnings("unchecked") // ArrayList<Integer>[] cannot be created without a raw array
	private static void buildTree(Reader in, int n) throws IOException {
		if (n >= N)
			N = n + 1; // node ids 1..n must index every per-node array
		// Depth is at most n-1, and lca() needs levels up to floor(log2(depth)).
		int levels = Math.max(LOG, 32 - Integer.numberOfLeadingZeros(n));

		sum = time = 0;
		adj = new ArrayList[N];
		w = new int[N];
		st = new int[N];
		et = new int[N];
		a = new int[2 * N];
		par = new int[N][levels];
		depth = new int[N];
		for (int i = 0; i < N; i++)
			adj[i] = new ArrayList<Integer>();

		// Compress weights so they can index cnt[]; ids are handed out in first-seen order.
		HashMap<Integer, Integer> ids = new HashMap<Integer, Integer>();
		for (int i = 1; i <= n; i++) {
			int raw = in.nextInt();
			Integer id = ids.get(raw);
			if (id == null) {
				id = ids.size() + 1;
				ids.put(raw, id);
			}
			w[i] = id;
		}

		for (int i = 0; i < n - 1; i++) {
			int u = in.nextInt();
			int v = in.nextInt();
			adj[u].add(v);
			adj[v].add(u);
		}

		for (int j = 1; j <= n; j++)
			Arrays.fill(par[j], -1);
		depth[1] = 0;
		dfs(1, 0);
		for (int i = 1; i < levels; i++)
			for (int j = 1; j <= n; j++)
				if (par[j][i - 1] != -1)
					par[j][i] = par[par[j][i - 1]][i - 1];
	}

	/** Reads m queries and turns each tree path into a half-open Euler-tour range. */
	private static Query[] readQueries(Reader in, int m) throws IOException {
		Query[] q = new Query[m];
		for (int i = 0; i < m; i++) {
			int u = in.nextInt();
			int v = in.nextInt();
			// Make u the deeper endpoint, so that if one endpoint is an ancestor it is v.
			if (depth[u] < depth[v]) {
				int tmp = u;
				u = v;
				v = tmp;
			}
			int lc = lca(u, v);
			if (lc == v) {
				// v is an ancestor of u: positions st[v]..st[u] cover the path, lca included.
				q[i] = new Query(st[v], st[u] + 1, i, -1);
			} else if (st[v] > et[u]) {
				// u's subtree closes before v is entered; the lca is outside the range.
				q[i] = new Query(et[u], st[v] + 1, i, lc);
			} else {
				q[i] = new Query(et[v], st[u] + 1, i, lc);
			}
		}
		return q;
	}

	/** Sorts the ranges in Mo's order, sweeps a window over them and returns answers in input order. */
	private static int[] answerQueries(Query[] q, int n) {
		// The tour has 2n positions; never use a block size of 0 (division by zero in the comparator).
		final int block = Math.max(1, (int) Math.sqrt(2.0 * n));
		Arrays.sort(q, new Comparator<Query>() {
			@Override
			public int compare(Query q1, Query q2) {
				int b1 = q1.l / block, b2 = q2.l / block;
				if (b1 != b2)
					return Integer.compare(b1, b2);
				return Integer.compare(q1.r, q2.r);
			}
		});

		occ = new int[N];
		cnt = new int[N];
		sum = 0;
		int[] ans = new int[q.length];
		int L = 0, R = 0; // the window is a[L..R)
		for (Query query : q) {
			// Grow first, then shrink, so the window only ever removes positions it holds.
			while (L > query.l) add(a[--L]);
			while (R < query.r) add(a[R++]);
			while (L < query.l) remove(a[L++]);
			while (R > query.r) remove(a[--R]);
			// The lca's tour positions are outside the range, so count its weight if it is absent.
			boolean lcaMissing = query.lca != -1 && cnt[w[query.lca]] == 0;
			ans[query.i] = lcaMissing ? sum + 1 : sum;
		}
		return ans;
	}

	/** Takes one tour position of node x out of the window, updating cnt and sum. */
	static void remove(int x)
	{
		int wt = w[x];
		occ[x]--;
		if (occ[x] == 1) {
			// x had both positions inside (off the path); now only one: it is back on the path.
			if (++cnt[wt] == 1)
				sum++;
		} else if (--cnt[wt] == 0) {
			// x's single position left the window, and it was the last node with this weight.
			sum--;
		}
	}

	/** Puts one tour position of node x into the window, updating cnt and sum. */
	static void add(int x)
	{
		int wt = w[x];
		occ[x]++;
		if (occ[x] == 2) {
			// Both positions inside means the window passed through x's whole subtree: off the path.
			if (--cnt[wt] == 0)
				sum--;
		} else if (++cnt[wt] == 1) {
			sum++;
		}
	}

	/**
	 * One query: the half-open tour range [l, r), its index i in the input, and the lca whose weight
	 * must be added separately (-1 when the range already contains it).
	 */
	static class Query{
		int l,r,i,lca;

		Query(int l,int r,int i,int lca)
		{
			this.lca = lca;
			this.l = l;
			this.r = r;
			this.i = i;
		}
	}

	/** Returns the lowest common ancestor of u and v using depth[] and the par table. */
	static int lca(int u,int v)
	{
		// The lifting below requires u to be the deeper node.
		if (depth[u] < depth[v]) {
			int tmp = u;
			u = v;
			v = tmp;
		}
		int lg, i;
		// lg = floor(log2(depth[u])), or -1 when u is the root.
		for (lg = 0; (1<<lg) <= depth[u]; lg++);
		lg--;
		// Lift u to the depth of v.
		for(i=lg; i>=0; i--)
			if ( depth[u] - (1<<i) >= depth[v])
				u = par[u][i];
		if (u == v)
			return u;
		// Lift both while their ancestors differ; they stop just below the lca.
		for(i = lg; i >= 0; i--){
			if (par[u][i] != -1 && par[u][i] != par[v][i])
			{	u = par[u][i]; v = par[v][i];}
		}
		return par[u][0];
	}

	/**
	 * Euler tour starting at curr (whose parent is prev). Assigns st/et and a[] in visiting order,
	 * and sets depth and par[..][0] for every descendant. Implemented with an explicit stack so
	 * that deep trees (e.g. a 40000-node path) cannot overflow the JVM thread stack.
	 */
	static void dfs(int curr,int prev)
	{
		int cap = adj.length;
		int[] node = new int[cap];    // explicit stack: node being visited
		int[] parent = new int[cap];  // its parent in the tree
		int[] nextIdx = new int[cap]; // next adjacency index to examine
		int top = 0;
		node[0] = curr;
		parent[0] = prev;
		st[curr] = ++time;
		a[time] = curr;
		while (top >= 0) {
			int u = node[top];
			if (nextIdx[top] < adj[u].size()) {
				int next = adj[u].get(nextIdx[top]++);
				if (next != parent[top]) {
					depth[next] = depth[u] + 1;
					par[next][0] = u;
					st[next] = ++time;
					a[time] = next;
					top++;
					node[top] = next;
					parent[top] = u;
					nextIdx[top] = 0;
				}
			} else {
				et[u] = ++time;
				a[time] = u;
				top--;
			}
		}
	}

	/** Euclidean distance between (x1, y1) and (x2, y2), without intermediate overflow/underflow. */
	static double dist(double x1,double y1,double x2,double y2)
	{
		return Math.hypot(x1 - x2, y1 - y2);
	}

	/**
	 * Returns a^b mod c by square-and-multiply; expects b >= 0 and c >= 1. The base is reduced into
	 * [0, c) first, so large or negative bases work. Products are taken in long, so c must be at
	 * most about 3.03e9.
	 */
	static long modpowIter(long a, long b, long c) {
		long base = a % c;
		if (base < 0)
			base += c;
		long ans = 1 % c; // 0 when c == 1
		while (b > 0) {
			if ((b & 1) == 1)
				ans = (ans * base) % c;
			base = (base * base) % c;
			b >>= 1;
		}
		return ans;
	}

	/** Greatest common divisor by Euclid's algorithm; gcd(a, 0) is a. */
	static int gcd(int a,int b)
	{
		if(b == 0)
			return a;
		else return gcd(b,a%b);
	}

	/** Least common multiple; divides before multiplying so a*b cannot overflow. lcm with 0 is 0. */
	static int lcm(int a, int b)
	{
		if (a == 0 || b == 0)
			return 0;
		return (a / gcd(a, b)) * b;
	}

	/**
	 * Fast token reader over a byte stream with a 64 KiB buffer. Numbers are parsed straight from
	 * ASCII bytes, stopping at the first non-digit. At end of input, read() keeps returning -1.
	 */
	static class Reader
	{
		final private int BUFFER_SIZE = 1 << 16;
		private DataInputStream din;
		private byte[] buffer;
		private int bufferPointer, bytesRead;

		/** Reads from standard input. */
		public Reader()
		{
			din = new DataInputStream(System.in);
			buffer = new byte[BUFFER_SIZE];
			bufferPointer = bytesRead = 0;
		}

		/** Reads from the file named {@code file_name}. */
		public Reader(String file_name) throws IOException
		{
			din = new DataInputStream(new FileInputStream(file_name));
			buffer = new byte[BUFFER_SIZE];
			bufferPointer = bytesRead = 0;
		}

		/**
		 * Returns the rest of the current line without its "\n" or "\r\n" terminator.
		 * Returns "" at end of input.
		 */
		public String readLine() throws IOException
		{
			byte[] buf = new byte[64];
			int cnt = 0;
			while (true)
			{
				byte c = read();
				if (atEof() || c == '\n')
					break;
				if (cnt == buf.length)
					buf = Arrays.copyOf(buf, cnt * 2); // grow instead of overflowing on long lines
				buf[cnt++] = c;
			}
			if (cnt > 0 && buf[cnt - 1] == '\r')
				cnt--;
			return new String(buf, 0, cnt);
		}

		/** Parses the next optionally signed decimal int; throws EOFException at end of input. */
		public int nextInt() throws IOException
		{
			int ret = 0;
			byte c = skipBlanks();
			boolean neg = (c == '-');
			if (neg)
				c = read();
			do
			{
				ret = ret * 10 + c - '0';
			}  while ((c = read()) >= '0' && c <= '9');
			return neg ? -ret : ret;
		}

		/** Parses the next optionally signed decimal long; throws EOFException at end of input. */
		public long nextLong() throws IOException
		{
			long ret = 0;
			byte c = skipBlanks();
			boolean neg = (c == '-');
			if (neg)
				c = read();
			do {
				ret = ret * 10 + c - '0';
			}
			while ((c = read()) >= '0' && c <= '9');
			return neg ? -ret : ret;
		}

		/** Parses the next plain decimal (no exponent); throws EOFException at end of input. */
		public double nextDouble() throws IOException
		{
			double ret = 0, div = 1;
			byte c = skipBlanks();
			boolean neg = (c == '-');
			if (neg)
				c = read();

			do {
				ret = ret * 10 + c - '0';
			}
			while ((c = read()) >= '0' && c <= '9');

			if (c == '.')
			{
				while ((c = read()) >= '0' && c <= '9')
				{
					ret += (c - '0') / (div *= 10);
				}
			}
			return neg ? -ret : ret;
		}

		/**
		 * Skips bytes <= ' ' (negative, i.e. non-ASCII, bytes included) and returns the first other
		 * byte; throws EOFException if the input ends first.
		 */
		private byte skipBlanks() throws IOException
		{
			byte c = read();
			while (c <= ' ')
			{
				if (atEof())
					throw new EOFException("no more tokens in input");
				c = read();
			}
			return c;
		}

		/** True once the underlying stream has reported end of input. */
		private boolean atEof()
		{
			return bytesRead == -1;
		}

		private void fillBuffer() throws IOException
		{
			bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
			if (bytesRead == -1)
				buffer[0] = -1;
		}

		/** Returns the next byte, or -1 at (and after) end of input. */
		private byte read() throws IOException
		{
			if (bufferPointer == bytesRead)
				fillBuffer();
			if (bytesRead == -1)
				return -1; // stay at end of input instead of returning stale buffer contents
			return buffer[bufferPointer++];
		}

		public void close() throws IOException
		{
			if (din == null)
				return;
			din.close();
		}
	}
}