import java.util.*;
import java.lang.*;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.BufferedInputStream;
// import java.io.BufferedOutputStream;
// import java.io.PrintWriter;
class Main
{
static Node[] nodes ;
static Query[] queries ;
static int N ;
static int M ;
static int[] lvl ;
static int MAXLVL = 18;
static int[][] dp ;
static int[] ans ;
static int[] dt ;
static int[] ID ;
static int MAXWT ;
static int time ;
static boolean[] vis ;
static int[] BL ;
static int blockSize;
static int lnk ;
static int wei;
static int[] A ;
static Node nd;
static boolean vi ;
static class Node{
int i ;
int st;
int en ;
List<Integer> links = null;
Node(int i){
this.i = i ;
}
public void add(int link){
if(links == null){
this.links= new ArrayList<Integer>(1) ;
}
this.links.add(link);
}
return "i = "+i + " st = "+st+" en = "+en+" \n" ;
}
}
/**
 * Returns the lowest common ancestor of vertices {@code u} and {@code v}.
 * Uses the binary-lifting table {@code dp}, where {@code dp[x][k]} is the 2^k-th
 * ancestor of {@code x}, and the depths stored in {@code lvl}.
 */
static int lca(int u, int v){
    // Arrange for v to be the deeper of the two vertices.
    if (lvl[u] > lvl[v]) {
        int tmp = u;
        u = v;
        v = tmp;
    }

    // Lift v to u's depth: jump 2^k for every set bit k of the depth gap.
    int gap = lvl[v] - lvl[u];
    for (int k = 0; k < MAXLVL; k++) {
        if ((gap & (1 << k)) != 0) {
            v = dp[v][k];
        }
    }
    if (u == v) {
        return u;
    }

    // Raise both vertices while their ancestors differ; they stop just below the LCA.
    for (int k = MAXLVL - 1; k >= 0; k--) {
        int upU = dp[u][k];
        int upV = dp[v][k];
        if (upU != upV) {
            u = upU;
            v = upV;
        }
    }
    return dp[u][0];
}
static class Query implements Comparable<Query>{
int id ;
int el;
int er ;
int lca ;
public Query(int id, int lca){
this.lca = lca;
this.id = id ;
}
public int compareTo(Query q2){
if(((this.el-1)/blockSize) != ((q2.el-1)/blockSize)){
return this.el- q2.el ;
}else{
return this.er-q2.er ;
}
}
return "id = "+id+" el="+el+" er="+er+" lca="+lca+"\n" ;
}
}
/**
 * Depth-first traversal from {@code ind}. For every vertex it reaches, it records:
 * the depth in {@code lvl}, the parent in {@code dp[x][0]}, the Euler-tour entry and
 * exit times in {@code nodes[x].st} / {@code nodes[x].en}, and the inverse mapping
 * {@code ID[time] = x}. The shared {@code time} counter is advanced along the way.
 *
 * <p>The traversal uses an explicit stack instead of recursion, so a deep
 * (e.g. path-shaped) tree cannot overflow the JVM call stack. Neighbours are
 * examined in list order, exactly as in the recursive formulation, so all
 * recorded values are the same.
 *
 * @param ind   start vertex; nothing happens if it is already visited
 * @param level depth assigned to the parent of {@code ind}; {@code ind} gets {@code level + 1}
 */
static void dfs(int ind, int level) {
    if (vis[ind]) {
        return;
    }
    // Each frame is {vertex, index of the next neighbour to examine}.
    Deque<int[]> stack = new ArrayDeque<int[]>();
    enterVertex(ind, level);
    stack.push(new int[] {ind, 0});
    while (!stack.isEmpty()) {
        int[] frame = stack.peek();
        int cur = frame[0];
        List<Integer> links = nodes[cur].links;
        if (links != null && frame[1] < links.size()) {
            int child = links.get(frame[1]++);
            if (!vis[child]) {
                dp[child][0] = cur;
                enterVertex(child, lvl[cur]);
                stack.push(new int[] {child, 0});
            }
        } else {
            // Every neighbour handled: record the exit time.
            stack.pop();
            nodes[cur].en = ++time;
            ID[time] = cur;
        }
    }
}

/** Marks {@code x} visited, gives it depth {@code parentLevel + 1} and records its entry time. */
private static void enterVertex(int x, int parentLevel) {
    vis[x] = true;
    lvl[x] = parentLevel + 1;
    nodes[x].st = ++time;
    ID[time] = x;
}
/**
 * Iterative Euler-tour flattening rooted at vertex 1. It fills {@code nodes[x].st},
 * {@code nodes[x].en}, {@code ID}, {@code lvl} and {@code dp[x][0]}. The visible code
 * calls {@code dfs} instead of this method.
 *
 * <p>This method uses its own clock, starting at 1, and its own visited array, so the
 * static {@code time} and {@code vis} are left untouched. Unlike the values {@code main}
 * sets up, the root gets depth 0 and parent 0 here.
 */
static void flattenTree(){
    int clock = 1;
    boolean[] seen = new boolean[nodes.length];
    Deque<Integer> stack = new ArrayDeque<Integer>();
    stack.push(1);
    lvl[1] = 0;
    dp[1][0] = 0;
    while (!stack.isEmpty()) {
        int curr = stack.peek();
        if (seen[curr]) {
            // Back on top: everything pushed above it has finished, so leave it.
            nodes[curr].en = clock;
            ID[clock++] = curr;
            stack.pop();
            continue;
        }
        nodes[curr].st = clock;
        ID[clock++] = curr;
        seen[curr] = true;
        List<Integer> links = nodes[curr].links;
        if (links == null) {
            // No edges (e.g. a single-vertex tree): nothing to expand.
            continue;
        }
        for (int i = 0; i < links.size(); i++) {
            int next = links.get(i);
            if (!seen[next]) {
                dp[next][0] = curr;
                lvl[next] = lvl[curr] + 1;
                stack.push(next);
            }
        }
    }
}
/**
 * Builds the binary-lifting table level by level. For each vertex,
 * {@code dp[v][p]} becomes {@code dp[dp[v][p-1]][p-1]}, i.e. the 2^p-th ancestor.
 * Entries whose level-(p-1) ancestor is 0 are left as they are.
 * Presumably 0 is the "no ancestor" sentinel; {@code main} makes the root its own parent.
 */
static void computeSparseMatrix(){
    for (int pow = 1; pow < MAXLVL; pow++) {
        for (int vertex = 1; vertex < N; vertex++) {
            int halfway = dp[vertex][pow - 1];
            if (halfway == 0) {
                continue;
            }
            dp[vertex][pow] = dp[halfway][pow - 1];
        }
    }
}
int curL = queries[0].el ;
int curR = curL-1;
int res = 0 ;
vis =new boolean[N+1] ;
dt = new int[MAXWT+1];
for(int i= 0 ;i<M;i++){
while(curL<queries[i].el){
res = countdt(ID[curL++], res) ;
}
while(curL > queries[i].el){
res = countdt(ID[--curL], res) ;
}
while(curR<queries[i].er){
res = countdt(ID[++curR], res) ;
}
while(curR>queries[i].er){
res = countdt(ID[curR--], res) ;
}
// int l = ID[curL] ;
// int r = ID[curR] ;
if(queries[i].lca != ID[curL] && queries[i].lca != ID[curR]){
res = countdt(queries[i].lca, res) ;
}
ans[queries[i].id] = res ;
if(queries[i].lca != ID[curL] && queries[i].lca != ID[curR]){
res = countdt(queries[i].lca, res) ;
}
// System.out.println(queries[i]) ;
}
// StringBuilder output = new StringBuilder() ;
// PrintWriter out = new PrintWriter(new BufferedOutputStream(System.out));
for(int i = 0 ;i<M ;i++){
// output.append(ans[i]+"\n") ;
// bw.write(ans[i]+"");
// bw.newLine();
// out.println(ans[i]);
}
// out.flush() ;
// System.out.print(output) ;
}
/**
 * Toggles vertex {@code i} in or out of the current window and returns the adjusted
 * count of distinct weight ids. {@code dt[w]} counts in-window vertices whose
 * compressed weight is {@code w}, and {@code vis[i]} tracks membership.
 * The scratch statics {@code wei} and {@code vi} are updated as before.
 *
 * @param i   vertex to toggle
 * @param res distinct-weight count before the toggle
 * @return distinct-weight count after the toggle
 */
static int countdt(int i, int res) {
    wei = A[i];
    vi = vis[i];
    if (vi) {
        // Leaving: the weight vanishes once its last occurrence is removed.
        if (--dt[wei] == 0) {
            res--;
        }
    } else {
        // Entering: a weight that had no occurrences adds a new distinct value.
        if (dt[wei]++ == 0) {
            res++;
        }
    }
    vis[i] = !vi;
    return res;
}
{
long startTime
= System.
currentTimeMillis() ; int bytes= 1<<4 ;
N = br.nextInt()+1 ;
M = br.nextInt() ;
nodes= new Node[N] ;
queries = new Query[M] ;
vis =new boolean[N] ;
dp = new int[N][MAXLVL] ;
lvl = new int[N] ;
Node nd ;
ID = new int[2*N-1] ;
MAXWT = 0 ;
Map
<Integer, Integer
> convert
=new HashMap
<Integer,Integer
>(N
+100,1f
) ; int counter = 0 ;
A = new int[N] ;
int inp = 0 ;
int w = 0 ;
for(int i = 1 ;i< N ;i++){
w = br.nextInt() ;
if(convert.containsKey(w)) {
A[i] = convert.get(w) ; ;
}else {
convert.put(w, ++counter) ;
MAXWT = counter ;
A[i] = counter ;
}
nodes[i] = new Node(i) ;
}
int u ;
int v ;
for(int i = 0 ;i< N-2 ;i++){
u = br.nextInt() ;
v = br.nextInt() ;
nodes[u].add(v) ;
nodes[v].add(u) ;
}
time = 0 ;
dp[1][0] = 1;
dfs(1,0) ;
blockSize = 600 ;
computeSparseMatrix() ;
for(int i= 0 ;i<M;i++){
u = br.nextInt() ;
v = br.nextInt() ;
if(nodes[u].st > nodes[v].st){
u = u+v;
v = u-v ;
u = u-v ;
}
int lc = lca(u,v) ;
Query query = new Query(i,lc) ;
if(lc == u){
query.el = nodes[u].st ;
query.er = nodes[v].st ;
}else{
query.el = nodes[u].en;
query.er = nodes[v].st ;
}
queries[i] = query ;
}
br.close();
ans = new int[M] ;
moAlgo();
}
/**
 * Debug helper: writes each row of {@code dp} to standard output on its own line,
 * formatted with {@link Arrays#toString(int[])}.
 *
 * @param dp matrix to dump (e.g. the binary-lifting table)
 */
static void print(int[][] dp) {
    for (int[] row : dp) {
        System.out.println(Arrays.toString(row));
    }
}
{
final private int BUFFER_SIZE = 1 << 8;
private byte[] buffer;
private int bufferPointer, bytesRead;
{
buffer = new byte[BUFFER_SIZE];
bufferPointer = bytesRead = 0;
}
{
buffer = new byte[BUFFER_SIZE];
bufferPointer = bytesRead = 0;
}
{
byte[] buf = new byte[64]; // line length
int cnt = 0, c;
while ((c = read()) != -1)
{
if (c == '\n')
break;
buf[cnt++] = (byte) c;
}
return new String(buf,
0, cnt
); }
{
int ret = 0;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do
{
ret = ret * 10 + c - '0';
} while ((c = read()) >= '0' && c <= '9');
if (neg)
return -ret;
return ret;
}
{
long ret = 0;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do {
ret = ret * 10 + c - '0';
}
while ((c = read()) >= '0' && c <= '9');
if (neg)
return -ret;
return ret;
}
{
double ret = 0, div = 1;
byte c = read();
while (c <= ' ')
c = read();
boolean neg = (c == '-');
if (neg)
c = read();
do {
ret = ret * 10 + c - '0';
}
while ((c = read()) >= '0' && c <= '9');
if (c == '.')
{
while ((c = read()) >= '0' && c <= '9')
{
ret += (c - '0') / (div *= 10);
}
}
if (neg)
return -ret;
return ret;
}
{
bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
if (bytesRead == -1)
buffer[0] = -1;
}
{
if (bufferPointer == bytesRead)
fillBuffer();
return buffer[bufferPointer++];
}
{
if (din == null)
return;
din.close();
}
}
}