Codeforces #169 Div2 E "Little Girl and Problem on Trees"
問題
一応図解
グラフは
こんな感じになる。
クエリは任意の範囲に値を追加する操作があるので、
BIT(BinaryIndexedTree)たんに登場してもらう。
こんな感じで各グループとグループ共通のBITを用意する。
クエリ0(追加クエリ)が飛んできた時は
こんな感じで
もし、頂点1にクエリが飛んできた場合、
それはすべてのグループの深さdまで値を追加する事になるので、
共通BITに追加する
BITの範囲を超えるようなクエリはどうするか
後ろはただ範囲に収まるように切れば大丈夫。
前にはみ出た分は頂点1にその分だけ追加した時と同じ効果になる。
なのではみ出し分は共通BITの方に追加して、
残ったのはグループのBITに追加するようにする。
クエリ1(値を返せクエリ)は簡単で
指定された頂点の深さを確認し、
共通BITとグループのBITの同じ深さの値を足したものとなる。
ソースコード
※深さとか上の図解と合ってない
// BITクラス宣言 class BinaryIndexedTree { public: BinaryIndexedTree(); ~BinaryIndexedTree(); void init( int size ); int elementNumber(); void add( int a, int b, int x ); long long sum( int a, int b ) const; private: void add( int a, int b, int x, int i, int l, int r ); long long sum( int a, int b, int i, int l, int r ) const; int m_ElementNumber; int m_ArrayCapacity; long long* m_tree_a; long long* m_tree_b; }; // プロトタイプ関数 void dfs( int v, int pv, int bitId, int depth ); void query0( int v, int x, int d ); int query1( int v ); const int MAX_N = 100050; // グラフ vector<int> G[MAX_N]; // 各頂点が属するBITとその深さ int bitID[MAX_N]; int bitDepth[MAX_N]; // 頂点1用のBITとそれ以外のBIT BinaryIndexedTree bit_V1; BinaryIndexedTree bits[MAX_N]; int main(){ int N, Q; cin >> N >> Q; // グラフ初期化 for( int i = 0; i < N - 1; i++ ){ int v1, v2; cin >> v1 >> v2; G[v1].push_back(v2); G[v2].push_back(v1); } // BITを初期化 bit_V1.init( MAX_N ); int n = G[1].size(); int bitid = 0; for( int i = 0; i < n; i++ ){ int v = G[1][i]; dfs( v, 1, bitid, 0 ); bitid++; } // クエリ処理 for( int i = 0; i < Q; i++ ){ int op, v, x, d; cin >> op; if( op == 0 ){ cin >> v >> x >> d; query0( v, x, d ); }else{ cin >> v; printf( "%d\n", query1( v ) ); } } return 0; } // DFSで頂点の属するBITとその深さを初期化する void dfs( int v, int pv, int bitId, int depth ){ bitID[v] = bitId; bitDepth[v] = depth; if( G[v].size() == 1 ){ bits[bitId].init( depth + 1 ); }else{ int nv = G[v][0] != pv ? G[v][0] : G[v][1]; dfs( nv, v, bitId, depth+1 ); } } // クエリ0 void query0( int v, int x, int d ){ if( v == 1 ){ // 頂点1であれば、頂点1用BITに追加 bit_V1.add( 0, d+1, x ); }else{ int bitid = bitID[v]; int depth = bitDepth[v]; int exceed = d - depth - 1; if( exceed >= 0 ){ // はみ出した分を頂点1用BITに追加 bit_V1.add( 0, exceed+1, x ); int a = exceed; int b = depth + d; if( b > bits[bitid].elementNumber() ) b = bits[bitid].elementNumber(); if( a <= b ) bits[bitid].add( a, b+1, x ); }else{ int a = depth - d; int b = depth + d; if( b > bits[bitid].elementNumber() ) b = bits[bitid].elementNumber(); bits[bitid].add( a, b+1, x ); } } } // クエリ1 int query1( int v ){ long long res; if( v == 1 ){ res = bit_V1.sum( 0, 1 ); }else{ int bitid = bitID[v]; int depth = bitDepth[v]; res = bit_V1.sum( depth+1, depth+2 ); res += bits[bitid].sum( depth, depth+1 ); } return (int)res; } // BITクラス定義 BinaryIndexedTree::BinaryIndexedTree() { } BinaryIndexedTree::~BinaryIndexedTree() { delete[] m_tree_a; m_tree_a = 0; delete[] m_tree_b; m_tree_b = 0; } void BinaryIndexedTree::init( int elementNumber ) { m_ElementNumber = elementNumber; int power = 0; while( elementNumber != 0 ){ elementNumber >>= 1; power++; } m_ArrayCapacity = 1 << (power + 1); m_tree_a = new long long[m_ArrayCapacity]; m_tree_b = new long long[m_ArrayCapacity]; for( int i = 0; i < m_ArrayCapacity; i++ ){ m_tree_a[i] = m_tree_b[i] = 0; } } int BinaryIndexedTree::elementNumber() { return m_ElementNumber; } void BinaryIndexedTree::add( int a, int b, int x ) { add( a, b, x, 0, 0, m_ElementNumber ); } long long BinaryIndexedTree::sum( int a, int b ) const { return sum( a, b, 0, 0, m_ElementNumber ); } void BinaryIndexedTree::add( int a, int b, int x, int i, int l, int r ) { if( a <= l && r <= b ){ m_tree_a[i] += x; }else if( a < r && l < b ){ m_tree_b[i] += (long long)x * (min( b, r ) - max( a, l )); add( a, b, x, i*2+1, l, (l+r)/2 ); add( a, b, x, i*2+2, (l+r)/2, r ); } } long long BinaryIndexedTree::sum( int a, int b, int i, int l, int r ) const { long long res = 0; if( a <= l && r <= b ){ res = m_tree_a[i] * (r - l) + m_tree_b[i]; }else if( a < r && l < b ){ res += m_tree_a[i] * (min( b, r ) - max( a, l )); res += sum( a, b, i*2+1, l, (l+r)/2 ); res += sum( a, b, i*2+2, (l+r)/2, r ); } return res; }