ccf202312-3 树上搜索

发布于 2024-01-23  4014 次阅读


ccf树上搜索

原题链接

这段c++的代码写的很优雅:

#include "bits/stdc++.h"
using namespace std;
#define For(i, j, n) for(int i=j;i<=n;i++)
#define Fol(i, j, n) for(int i=j;i>=n;i--)
typedef long long LL;
const int N = 2e3 + 5;
struct node {
    int w;//自身权重
    LL wt;//自己与后代的权重
    unordered_set<int> son;//孩子们
} tr[N];
bool st[N];//判断这个类别是否被去除
set<int> seg;//待查询的类别集和

//更新以root为根的树上所有节点的wt
LL dfs(int root, set<int> &seg) {
    seg.insert(root);
    LL res = 0;
    for (auto child: tr[root].son) {
        if (st[child]) continue;
        res += dfs(child, seg);
    }
    tr[root].wt = res + tr[root].w;
    return tr[root].wt;
}
//查询wsigma最小的节点
int query(int root, set<int> &seg) {
    LL wmin = LONG_LONG_MAX, pos = -1;
    for (auto x: seg) {
        LL wsigma = abs(tr[root].wt - 2 * tr[x].wt);
        if (wmin > wsigma) {
            wmin = wsigma;
            pos = x;
        }
    }
    return pos;
}
//用户询问判断ch是否被归类fa或者fa的后代
bool judge(int fa, int ch) {
    if (fa == ch) return true;
    bool flag = false;
    for (auto x: tr[fa].son) {
        flag |= judge(x, ch);
        if(flag) break;
    }
    return flag;
}
int main() {
    int n, m, fa;
    scanf("%d %d", &n, &m);
    For(i, 1, n) scanf("%d", &tr[i].w);
    For(i, 2, n) {
        scanf("%d", &fa);
        tr[fa].son.insert(i);
    }
    For(i, 1, m) {
        memset(st, false, sizeof st);
        int root = 1, x;
        scanf("%d", &x);
        while (1) {
            seg.clear();
            dfs(root, seg);
            if (seg.size() == 1)break;//直到只剩下一个类别,此时即可确定名词的类别
            int id = query(root, seg);
            printf("%d ", id);
            if (judge(id, x)) root = id;//如果用户回答是,保留该类别及其后代类别
            else st[id] = true;//否则仅保留其余类别
        }
        puts("");
    }
    return 0;
}

另外贴一个Java的:

也写的不错,不过有少量优化空间:

import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        int[] weigth = new int[n + 1];
        int[] father = new int[n + 1]; // 存储当前节点的father是谁
        List<Integer>[] son =  new List[n + 1]; //存储当前节点的儿子节点有什么
        for (int i = 1; i <= n; i++) {
            weigth[i] = sc.nextInt();
            son[i] = new ArrayList<>();
        }
        for (int i = 2; i <= n; i++) {
            father[i] = sc.nextInt();
            son[father[i]].add(i);
        }
        for (int i = 0; i < m; i++) {
            int j = sc.nextInt();
            int head = 1;
            ArrayList<Integer> removed = new ArrayList<>();
            do{
                long[] currentWeight = new long[n + 1];
                getWeight(weigth, son, removed, head, currentWeight);
                long[] w = new long[n + 1]; //存储最终的权重值,也就是题目要求的权重
                long minValue = Long.MAX_VALUE;
                int minIndex = 0;
                for (int k = 1; k <= n; k++) {
                    if (removed.contains(k))continue;
                    w[k] = Math.abs(2 * currentWeight[k] - currentWeight[head]);// 根据题目要求得出的最终权重公式
                    if (minValue > w[k]){
                        minValue = w[k];
                        minIndex = k;
                    }
                }
                if (isSon(father, minIndex, j, head)){
                    head = minIndex;
                }else {
                    removed.add(minIndex);
                }
                System.out.print(minIndex + " ");
            }while (hasTwo(son, removed, head));
            System.out.println(" ");
        }
    }

    //获取所有节点的权重值,此权重为 它和其全部后代类别的权重之和, 用res承接答案
    public static long getWeight(int[] weight, List<Integer>[] son, List<Integer> removed, int root, long[] res){
        if (removed.contains(root))return 0;
        long ans = weight[root];
        for (Integer temp : son[root]) {
            ans += getWeight(weight, son, removed, temp, res);
        }
        res[root] = ans;
        return ans;
    }

    //判断j节点是否是min节点的孩子
    public static boolean isSon(int[] father, int min ,int j, int head){
        if (min == head)return true;
        while (j != head){
            if (j == min)return true;
            j = father[j];
        }
        return false;
    }

    //判断当前树结构是否有两个及以上节点存在
    public static boolean hasTwo(List<Integer>[] son, List<Integer> removed,int head){
        for (Integer temp : son[head]) {
            if (!removed.contains(temp))return true;
        }
        return false;
    }
}