https://www.acmicpc.net/problem/20188
Tag : Tree dp, complement set
두가지 풀이로 풀어봤는데 둘 모두 꽤 교육적인 풀이였다.
sol1) Tree dp
\(sz_i:=\)i를 루트로 가지는 서브트리의 사이즈
\(sum_i:=\)i를 루트로 가지는 서브트리의 모든 원소에서 i로 향하는 경로의 합
이 두 정보를 가지고 있다면 dfs(x)에서 x를 lca로 가지는 두 정점에 대한 처리를 모두 할 수 있다.
1번 정점에서 x로 향하는 경로는 \(\binom{sz_i-1}{2}\)번 사용되며
이것과 \(\sum {dist*sz[i]+cnt*(sum[i]+sz[i])}\)의 합이 x를 lca로 가지는 모든 두 정점들의 답이다.
\(dist\)는 현재까지 살핀 서브트리의 원소들에서 x로 향하는 경로 합, \(cnt\)는 현재까지 살핀 서브트리 크기의 합이다.
시간복잡도 : \(O(N)\)
전체코드
#include "bits/stdc++.h"
#define pb push_back
#define fi first
#define se second
#define all(x) ((x).begin()), ((x).end())
#define compress(x) sort(all(x)), (x).erase(unique(all(x)),(x).end())
#define siz(x) ((int)((x).size()))
#define endl '\n'
using namespace std;
using ll = long long;
using pi = pair<int,int>;
using pl = pair<ll,ll>;
template<typename T>T rmin(T &a,T b){return a=min<T>(a,b);}
template<typename T>T rmax(T &a,T b){return a=max<T>(a,b);}
ll N,ans=0;
ll sz[303030];
ll sum[303030]; // sum[i]:=i서브트리의 모든 정점에서 i로 향하는 거리의 합
ll chk[303030];
vector<ll>v[303030];
void dfs(ll x,ll dep=0){
chk[x]=sz[x]=1;sum[x]=0;
ll dist=0,cnt=0;
ll t=0;
for(auto&i:v[x]){
if(chk[i])continue;
dfs(i,dep+1);
ans+=dep*cnt*sz[i];
ans+=dist*sz[i]+cnt*(sum[i]+sz[i]);
// t+=dep*cnt*sz[i]+dist*sz[i]+cnt*(sum[i]+sz[i]);
sz[x]+=sz[i];
sum[x]+=sum[i]+sz[i];
dist+=sum[i]+sz[i];
cnt+=sz[i];
}
ans+=dep*(sz[x]-1);
ans+=sum[x];
// t+=dep*(sz[x]-1)+sum[x];
// cout<<x<<' '<<t<<endl;
}
int main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>N;
for(int i=1;i<=N-1;i++){
int x,y;cin>>x>>y;
v[x].pb(y);v[y].pb(x);
}
dfs(1);
cout<<ans<<endl;
}
sol2) complement set
여집합을 이용하면 굉장히 쉽게 풀린다. 아무 간선이나 잡고 이를 \(e\)라 이름 붙여보자
\(e\)에서 부모쪽을 \(s\) 자식을 \(e\)라 한다면 간선 \(e\)는
\(\binom{N}{2}-\binom{N-sz_s}{2}\)개의 경우를 제외하고 모두 사용된다.
이 관찰을 하지 못하면 sol1대로 복잡하게 풀어야한다..
시간복잡도 : \(O(N)\)
전체코드
#include "bits/stdc++.h"
#define pb push_back
#define fi first
#define se second
#define all(x) ((x).begin()), ((x).end())
#define compress(x) sort(all(x)), (x).erase(unique(all(x)),(x).end())
#define siz(x) ((int)((x).size()))
#define endl '\n'
using namespace std;
using ll = long long;
using pi = pair<int,int>;
using pl = pair<ll,ll>;
template<typename T>T rmin(T &a,T b){return a=min<T>(a,b);}
template<typename T>T rmax(T &a,T b){return a=max<T>(a,b);}
ll N,ans=0;
ll sz[303030];
vector<ll>v[303030];
void dfs(ll x){
sz[x]=1;
for(auto&i:v[x]){
if(sz[i])continue;
dfs(i);
sz[x]+=sz[i];
}
}
#include <unistd.h>
#include <sys/stat.h>
#include <sys/mman.h>
int z[36];
char*c=(char*)mmap(0,z[12],1,2,0,fstat(0,(struct stat *)z));
inline int f(){int x=0;bool e;c+=e=*c=='-';while(*c>='0')x=10*x+*c++-'0';c++;return e?-x:x;}
int main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>N;ans=(N-1)*N*(N-1)/2;
vector<pi>edge(N-1);
for(int i=1;i<=N-1;i++){
int a,b;cin>>a>>b;
edge[i-1]={a,b};
v[a].pb(b);v[b].pb(a);
}
dfs(1);
for(auto&[i,j]:edge){
if(sz[i]<sz[j])swap(i,j);//i -> j
ans-=(N-sz[j])*(N-sz[j]-1)/2;
}
cout<<ans<<endl;
}
'BOJ' 카테고리의 다른 글
[BOJ 11973] Angry Cows (Silver) (0) | 2021.11.25 |
---|---|
[BOJ 8980] 택배 (0) | 2021.11.25 |
[BOJ 5551] 쇼핑몰 (0) | 2021.11.22 |
[BOJ 15732] 도토리 숨기기 (0) | 2021.11.22 |
[BOJ 5837] Poker Hands (0) | 2021.11.22 |