import java.io.* ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.Comparator ;
import java.util.HashMap ;
import java.util.StringTokenizer ;
public class Main{
public static final int mod = 1000000007 ;
public static ArrayList< Integer> adj[ ] ;
public static int w[ ] ,a[ ] ,st[ ] ,et[ ] ,par[ ] [ ] ,depth[ ] ,occ[ ] ,cnt[ ] ,time,sum;
public static int N = 40011 ;
// TODO Auto-generated method stub
int n = in.nextInt ( ) ;
int m = in.nextInt ( ) ;
sum = time = 0 ;
w = new int [ N] ;
st = new int [ N] ;
et = new int [ N] ;
a = new int [ 2 * N] ;
par = new int [ N] [ 16 ] ;
depth = new int [ N] ;
int block
= ( int ) Math .
sqrt ( n
) ;
for ( int i= 0 ; i< N; i++ )
adj[ i] = new ArrayList< Integer> ( ) ;
int no = 0 ;
for ( int i= 1 ; i<= n; i++ )
{
int tmp = in.nextInt ( ) ;
if ( ! hash.containsKey ( tmp) )
{
hash.put ( tmp,++ no) ;
}
w[ i] = hash.get ( tmp) ;
}
for ( int i= 0 ; i< n- 1 ; i++ )
{
int u = in.nextInt ( ) ;
int v = in.nextInt ( ) ;
adj[ u] .add ( v) ;
adj[ v] .add ( u) ;
}
for ( int i= 0 ; i< 16 ; i++ )
for ( int j= 1 ; j<= n; j++ )
par[ j] [ i] = - 1 ;
depth[ 1 ] = 0 ;
dfs( 1 ,0 ) ;
for ( int i= 1 ; i< 16 ; i++ )
for ( int j= 1 ; j<= n; j++ )
{
if ( par[ j] [ i- 1 ] != - 1 )
par[ j] [ i] = par[ par[ j] [ i- 1 ] ] [ i- 1 ] ;
}
Query q[ ] = new Query[ m] ;
int ans[ ] = new int [ m] ;
for ( int i= 0 ; i< m; i++ )
{
int u = in.nextInt ( ) ;
int v = in.nextInt ( ) ;
if ( depth [ u] < depth[ v] )
{
int tmp = u;
u = v;
v = tmp;
}
int lc = lca( u, v) ;
if ( lc == v)
{
q[ i] = new Query( st[ v] ,st[ u] + 1 ,i,- 1 ) ;
}
else {
if ( st[ v] > et[ u] )
{
q[ i] = new Query( et[ u] ,st[ v] + 1 ,i,lc) ;
}
else
{
q[ i] = new Query( et[ v] ,st[ u] + 1 ,i,lc) ;
}
}
}
Arrays .
sort ( q,
new Comparator
< Query
> ( ) {
public int compare( Query q1,Query q2)
{
if ( q1.l / block < q2.l / block)
return - 1 ;
else if ( q1.l / block > q2.l / block)
return 1 ;
else
return q1.r - q2.r ;
}
} ) ;
occ = new int [ N] ;
cnt = new int [ N] ;
sum = 0 ;
int L = 0 , R = 0 ;
StringBuilder sb = new StringBuilder( "" ) ;
for ( int i= 0 ; i< m; i++ )
{
int nextL = q[ i] .l ;
int nextR = q[ i] .r ;
while ( L < nextL)
{
remove( a[ L] ) ;
L++;
}
while ( L > nextL)
{
add( a[ L- 1 ] ) ;
L--;
}
while ( R < nextR)
{
add( a[ R] ) ;
R++;
}
while ( R > nextR )
{
remove( a[ R- 1 ] ) ;
R--;
}
if ( q[ i] .lca != - 1 && cnt[ w[ q[ i] .lca ] ] == 0 )
{
ans[ q[ i] .i ] = sum+ 1 ;
}
else
{
ans[ q[ i] .i ] = sum;
}
}
for ( int i= 0 ; i< m; i++ )
sb.append ( ans[ i] + "\n " ) ;
System .
out .
println ( sb.
toString ( ) ) ;
}
/**
 * Drops one Euler-tour entry of node {@code x} from the current Mo window.
 *
 * <p>{@code occ[x]} counts how many of x's Euler-tour entries lie inside the window; x is on the
 * queried path only while exactly one entry is inside. {@code cnt[c]} counts on-path nodes whose
 * compressed weight is {@code c}, and {@code sum} is the number of weights with a non-zero count.
 *
 * @param x tree node whose entry leaves the window
 */
static void remove(int x)
{
    final int colour = w[x];
    occ[x]--;
    if (occ[x] != 1) {
        // Normally 1 -> 0: x was on the path and now leaves it.
        cnt[colour]--;
        if (cnt[colour] == 0) {
            sum--;
        }
        return;
    }
    // 2 -> 1: x had been cancelled out by its second entry and now rejoins the path.
    cnt[colour]++;
    if (cnt[colour] == 1) {
        sum++;
    }
}
/**
 * Brings one Euler-tour entry of node {@code x} into the current Mo window.
 *
 * <p>When the first entry of x arrives, x joins the path and its weight is counted. When the
 * second entry arrives, x lies outside the path, so its earlier contribution is undone. {@code sum}
 * tracks how many distinct compressed weights currently have a non-zero {@code cnt}.
 *
 * @param x tree node whose entry enters the window
 */
static void add(int x)
{
    final int colour = w[x];
    occ[x]++;
    if (occ[x] == 2) {
        // Second entry: x is no longer on the path, withdraw its weight.
        cnt[colour]--;
        if (cnt[colour] == 0) {
            sum--;
        }
    } else {
        // First entry: x joins the path.
        cnt[colour]++;
        if (cnt[colour] == 1) {
            sum++;
        }
    }
}
/**
 * One path query expressed as a half-open window {@code [l, r)} over the Euler-tour array, plus
 * the bookkeeping needed to answer it out of order.
 */
static class Query{
    /** First Euler-tour position inside the window (inclusive). */
    int l;
    /** Euler-tour position just past the window (exclusive). */
    int r;
    /** Index of this query in input order; the answer is stored at this slot. */
    int i;
    /** Node that must be counted in addition to the window, or -1 when the window covers it. */
    int lca;

    Query(int l, int r, int i, int lca)
    {
        this.l = l;
        this.r = r;
        this.i = i;
        this.lca = lca;
    }
}
/**
 * Lowest common ancestor of {@code u} and {@code v} using binary lifting.
 *
 * <p>Reads {@code depth[]} and the jump table {@code par[x][k]}, which holds the 2^k-th ancestor
 * of x or -1 when that ancestor does not exist. Both must be filled in before this is called.
 *
 * <p>The arguments may be given in either order. The deeper node is lifted first. Without the
 * swap, a call with the shallower node first never lifted it and returned a wrong ancestor.
 *
 * @param u a tree node
 * @param v a tree node
 * @return the deepest node that is an ancestor of both (a node counts as its own ancestor)
 */
static int lca(int u, int v)
{
    if (depth[u] < depth[v]) {
        int tmp = u;
        u = v;
        v = tmp;
    }
    // lg = floor(log2(depth[u])), the largest jump that can ever be useful; -1 when u is the root.
    int lg, i;
    for (lg = 0; (1 << lg) <= depth[u]; lg++) ;
    lg--;
    // Bring u up to v's depth.
    for (i = lg; i >= 0; i--)
        if (depth[u] - (1 << i) >= depth[v])
            u = par[u][i];
    if (u == v)
        return u;
    // Lift both while their ancestors differ; they end up as children of the LCA.
    for (i = lg; i >= 0; i--) {
        if (par[u][i] != -1 && par[u][i] != par[v][i]) {
            u = par[u][i];
            v = par[v][i];
        }
    }
    return par[u][0];
}
/**
 * Euler-tour traversal from {@code curr}, using an explicit stack instead of recursion.
 *
 * <p>Each node is stamped twice:
 * <ul>
 *   <li>{@code st[v]} when it is entered;
 *   <li>{@code et[v]} after all of its children are finished.
 * </ul>
 * {@code a[t]} records the node that owns timestamp {@code t}. The traversal also sets
 * {@code depth[child] = depth[parent] + 1} and {@code par[child][0] = parent}. Children are visited
 * in adjacency-list order, so the timestamps match a recursive DFS exactly.
 *
 * <p>A recursive version needs one JVM stack frame per tree level. A path-shaped tree with tens of
 * thousands of nodes can then throw {@link StackOverflowError}; this version cannot.
 *
 * @param curr start node of the traversal
 * @param prev neighbour of {@code curr} that is treated as its parent and never descended into
 */
static void dfs(int curr, int prev)
{
    // One frame per stack level: the node, its parent, and the next adjacency index to examine.
    // The stack holds a single root-to-node path, so its depth never exceeds the vertex count.
    int capacity = adj.length + 1;
    int[] nodeAt = new int[capacity];
    int[] parentAt = new int[capacity];
    int[] nextEdge = new int[capacity];

    int top = 0;
    nodeAt[0] = curr;
    parentAt[0] = prev;
    st[curr] = ++time;
    a[time] = curr;

    while (top >= 0) {
        int u = nodeAt[top];
        if (nextEdge[top] < adj[u].size()) {
            int v = adj[u].get(nextEdge[top]++);
            if (v != parentAt[top]) {
                depth[v] = depth[u] + 1;
                par[v][0] = u;
                top++;
                nodeAt[top] = v;
                parentAt[top] = u;
                nextEdge[top] = 0;
                st[v] = ++time;
                a[time] = v;
            }
        } else {
            // Every neighbour of u has been handled: close its interval and pop back to the parent.
            et[u] = ++time;
            a[time] = u;
            top--;
        }
    }
}
/**
 * Euclidean distance between the points (x1, y1) and (x2, y2).
 *
 * @return sqrt((x1 - x2)^2 + (y1 - y2)^2)
 */
static double dist(double x1, double y1, double x2, double y2)
{
    double dx = x1 - x2;
    double dy = y1 - y2;
    return Math.sqrt(dx * dx + dy * dy);
}
/**
 * Computes {@code a^b mod c} by square-and-multiply.
 *
 * <p>The base is reduced into {@code [0, c)} before the loop, so every product stays below
 * {@code c^2}. Squaring an unreduced base such as 10^10 would overflow {@code long}. Starting from
 * {@code 1 % c} makes a modulus of 1 yield 0, and a negative base yields a non-negative residue.
 *
 * @param a base, any sign
 * @param b exponent; values {@code <= 0} yield {@code 1 % c}
 * @param c positive modulus; it must stay below about 3.03e9 so that {@code c^2} fits in a long
 * @return {@code a^b mod c}, in {@code [0, c)}
 */
static long modpowIter(long a, long b, long c) {
    long base = a % c;
    if (base < 0)
        base += c;
    long ans = 1 % c;
    while (b > 0) {
        if ((b & 1) == 1)
            ans = (ans * base) % c;
        base = (base * base) % c;
        b >>= 1;
    }
    return ans;
}
/**
 * Greatest common divisor by Euclid's algorithm.
 *
 * <p>Uses Java's truncating {@code %}, so the sign of the result follows the remainder chain.
 * {@code gcd(a, 0) == a}.
 */
static int gcd(int a, int b)
{
    while (b != 0) {
        int rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

/**
 * Least common multiple of {@code a} and {@code b}.
 *
 * <p>Divides by the gcd before multiplying, so the intermediate value never exceeds the result.
 * The naive {@code a * b} overflows {@code int} for inputs as small as 100000 and 30000. Returns 0
 * when either argument is 0, which also avoids dividing by {@code gcd(0, 0) == 0}.
 */
static int lcm(int a, int b)
{
    if (a == 0 || b == 0)
        return 0;
    return (a / gcd(a, b)) * b;
}
{
final private int BUFFER_SIZE = 1 << 16 ;
private byte [ ] buffer;
private int bufferPointer, bytesRead;
{
buffer = new byte [ BUFFER_SIZE] ;
bufferPointer = bytesRead = 0 ;
}
{
buffer = new byte [ BUFFER_SIZE] ;
bufferPointer = bytesRead = 0 ;
}
{
byte [ ] buf = new byte [ 64 ] ; // line length
int cnt = 0 , c;
while ( ( c = read( ) ) != - 1 )
{
if ( c == '\n ' )
break ;
buf[ cnt++ ] = ( byte ) c;
}
return new String ( buf,
0 , cnt
) ; }
{
int ret = 0 ;
byte c = read( ) ;
while ( c <= ' ' )
c = read( ) ;
boolean neg = ( c == '-' ) ;
if ( neg)
c = read( ) ;
do
{
ret = ret * 10 + c - '0' ;
} while ( ( c = read( ) ) >= '0' && c <= '9' ) ;
if ( neg)
return - ret;
return ret;
}
{
long ret = 0 ;
byte c = read( ) ;
while ( c <= ' ' )
c = read( ) ;
boolean neg = ( c == '-' ) ;
if ( neg)
c = read( ) ;
do {
ret = ret * 10 + c - '0' ;
}
while ( ( c = read( ) ) >= '0' && c <= '9' ) ;
if ( neg)
return - ret;
return ret;
}
{
double ret = 0 , div = 1 ;
byte c = read( ) ;
while ( c <= ' ' )
c = read( ) ;
boolean neg = ( c == '-' ) ;
if ( neg)
c = read( ) ;
do {
ret = ret * 10 + c - '0' ;
}
while ( ( c = read( ) ) >= '0' && c <= '9' ) ;
if ( c == '.' )
{
while ( ( c = read( ) ) >= '0' && c <= '9' )
{
ret += ( c - '0' ) / ( div *= 10 ) ;
}
}
if ( neg)
return - ret;
return ret;
}
{
bytesRead = din.read ( buffer, bufferPointer = 0 , BUFFER_SIZE) ;
if ( bytesRead == - 1 )
buffer[ 0 ] = - 1 ;
}
{
if ( bufferPointer == bytesRead)
fillBuffer( ) ;
return buffer[ bufferPointer++ ] ;
}
{
if ( din == null )
return ;
din.close ( ) ;
}
}
}
