bzoj2243染色 树链剖分+线段树

来源:互联网 发布:指南针软件怎么使用 编辑:程序博客网 时间:2024/06/06 17:16

2243: [SDOI2011]染色

Time Limit: 20 Sec  Memory Limit: 512 MB
Submit: 8230  Solved: 3073
[Submit][Status][Discuss]

Description

给定一棵有n个节点的无根树和m个操作,操作有2类:

1、将节点a到节点b路径上所有点都染成颜色c

2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“1122213段组成:“11、“222和“1

请你写一个程序依次完成这m个操作。

Input

第一行包含2个整数nm,分别表示节点数和操作数;

第二行包含n个正整数表示n个节点的初始颜色

下面 行每行包含两个整数xy,表示xy之间有一条无向边。

下面 行每行描述一个操作:

“C a b c”表示这是一个染色操作,把节点a到节点b路径上所有点(包括ab)都染成颜色c

“Q a b”表示这是一个询问操作,询问节点a到节点b(包括ab)路径上的颜色段数量。

Output

对于每个询问操作,输出一行答案。

Sample Input

6 5

2 2 1 2 1 1

1 2

1 3

2 4

2 5

2 6

Q 3 5

C 2 1 1

Q 3 5

C 5 1 2

Q 3 5

Sample Output

3

1

2

HINT

数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。

Source

第一轮day1


树链剖分,用线段树维护每段上的颜色数量以及左右端点上的颜色,合并时看左右端点颜色是否相同,不同左右相加否则减一
然后区间修改
注意跳轻链的时候要特判一下链头和跳过去的节点颜色是否相同,如果相同要把答案减一。判的时候用线段树判,因为这是修改过后的线段树

#include<iostream>#include<cstdlib>#include<cstdio>#include<cstring>#include<algorithm>#include<map>#include<cmath>#define maxn 101000#define inf 0x3f3f3f3f#define ls p << 1#define rs p << 1 | 1using namespace std;int read(){    char ch = getchar(); int x = 0, f = 1;    while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}    while(ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}    return x * f;} struct tree {    int l, r, cnt, lc, rc;    bool tag;    tree() : l(0), r(0), cnt(0), lc(-1), rc(-1), tag(false) {}}t[maxn * 10]; int pre[maxn], top;struct edge {    int to, next;    void add(int a, int b) {        to = b;        next = pre[a];        pre[a] = top++;    }}e[maxn * 2];void adds(int u, int v){    e[top].add(u, v);    e[top].add(v, u);}int n, m, tot, sz;int w[maxn], f[maxn], d[maxn], pos[maxn], deep[maxn], son[maxn], fw[maxn]; int lca(int u, int v){    while(d[u] != d[v])    {        if(deep[d[u]] < deep[d[v]]) swap(u, v);        u = f[d[u]];    }    if(deep[u] > deep[v]) swap(u, v);    return u;} void dfs1(int u, int fa){    son[u] = 1; f[u] = fa; deep[u] = deep[fa] + 1;    for(int i = pre[u]; ~i; i = e[i].next)    {        int v = e[i].to;        if(v == fa) continue;        dfs1(v, u);        son[u] += son[v];    }} void dfs2(int u, int chain){    pos[u] = ++tot; d[u] = chain; int k = 0;    for(int i = pre[u]; ~i; i = e[i].next)    {        int v = e[i].to;        if(v == f[u]) continue;        if(son[v] > son[k]) k = v;    }    if(!k) return;    dfs2(k, chain);    for(int i = pre[u]; ~i; i = e[i].next)    {        int v = e[i].to;        if(v != f[u] && v != k)            dfs2(v, v);    }} void update(int p) {    t[p].cnt = t[ls].cnt + t[rs].cnt;    if(t[ls].rc == t[rs].lc) --t[p].cnt;    t[p].lc = t[ls].lc;    t[p].rc = t[rs].rc;} void build_tree(int p, int L, int R) {    t[p].l = L; t[p].r = R; t[p].lc = w[L]; t[p].rc = w[R];    if(L == R) {        t[p].cnt = 1;        return;    }    int mid = L + R >> 1;    build_tree(ls, L, mid);    build_tree(rs, mid + 1, R);    update(p);} void init() {    n = read(); m = read(); memset(pre, -1, sizeof(pre)); tot = top = 0;    for(int i = 1;i <= n; ++i) fw[i] = read();    for(int i = 1;i < n; ++i) adds(read(), read());    dfs1(1, 0);    dfs2(1, 1);    for(int i = 1;i <= n; ++i) w[pos[i]] = fw[i];    build_tree(1, 1, n);} void paint(int p, int val){    t[p].lc = t[p].rc = val;    t[p].cnt = 1; t[p].tag = true;} void pushdown(int p) {    if(t[p].tag) {        paint(ls, t[p].lc);        paint(rs, t[p].lc);        t[p].tag = false;    }} void Seg_ch(int p, int st, int ed, int val) {    int l = t[p].l, r = t[p].r;    if(st == l && ed == r) {        paint(p, val);        return;    }    int mid = l + r >> 1;    pushdown(p);    if(st <= mid) Seg_ch(ls, st, min(mid, ed), val);    if(ed > mid) Seg_ch(rs, max(st, mid + 1), ed, val);    update(p);} void change(int u, int v, int val) {    while(d[u] != d[v]) {        if(deep[d[u]] < deep[d[v]]) swap(u, v);        Seg_ch(1, pos[d[u]], pos[u], val);        u = f[d[u]];    }    if(pos[u] > pos[v]) swap(u, v);    Seg_ch(1, pos[u], pos[v], val);} int Seg_sum(int p, int st, int ed) {    int l = t[p].l, r = t[p].r;    //cout<<p<<" "<<t[p].cnt<<" "<<st<<" "<<ed<<" "<<l<<" "<<r<<" "<<endl;    if(st == l && ed == r) return t[p].cnt;    int mid = l + r >> 1;    pushdown(p);    if(st > mid) return Seg_sum(rs, st, ed);    if(ed <= mid) return Seg_sum(ls, st, ed);     int ans = Seg_sum(ls, st, mid) + Seg_sum(rs, mid + 1, ed);    if(t[ls].rc == t[rs].lc) --ans;    return ans;} int Seg_col(int p, int pos) {    int l = t[p].l, r = t[p].r;    if(l == r) return t[p].lc;    pushdown(p);    int mid = l + r >> 1;    if(pos <= mid) return Seg_col(ls, pos);    else return Seg_col(rs, pos);} int query(int u, int v){    int ans = 0;    while(d[u] != d[v]) {        if(deep[d[u]] < deep[d[v]]) swap(u, v);        ans += Seg_sum(1, pos[d[u]], pos[u]);        if(Seg_col(1, pos[d[u]]) == Seg_col(1, pos[f[d[u]]])) ans--;          u = f[d[u]];    }    if(pos[u] > pos[v]) swap(u, v);    //cout<<pos[u]<<" "<<pos[v]<<endl;    ans += Seg_sum(1, pos[u], pos[v]);    return ans;} void solve() {    for(int i = 1;i <= m; ++i) {        char ch = getchar();        while(ch != 'Q' && ch != 'C') ch = getchar();        if(ch == 'C') {            int u = read(), v = read(), val = read();            change(u, v, val);        }        else {            int u = read(), v = read();            printf("%d\n", query(u, v));        }    }} int main(){    init();    solve();    return 0;}/*6 52 2 1 2 1 11 21 32 42 52 6C 4 5 3Q 5 3C 4 6 5Q 2 6Q 4 3*/


阅读全文
0 0
原创粉丝点击