【C/C++ Practice Out of Boredom (4)】Writing a Binary Indexed Tree with an Iterator Using C++ Templates, for Prefix Sums
Table of Contents
- Preface
- Full Code
- bitree.h
- bitree.cpp
- bitree_test.h
- main.cpp
- Summary
Preface
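Recently I had a whimsical idea and came up with a curious scenario. Suppose there are $N$ people, and each starts with $1$ yuan. At each step, $1$ yuan is given to one randomly chosen person, and each person's chance of being chosen is proportional to the money they currently hold (the original idea was to simulate the Matthew effect). What does the distribution of everyone's money look like in the end? As it turned out, the simulation produced an exponential distribution. That is easy enough to explain, because the process is memoryless, but it is not the point here, so I won't dwell on it.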
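The programming difficulty is making each person's chance of receiving money proportional to their current amount. A brute-force approach recomputes the prefix sums on every iteration to get the cumulative distribution function. It then draws a uniformly distributed random integer and uses lower_bound to find the index of the person who should receive the money. However, each iteration then costs $O(N)$, so running $N$ iterations already takes quadratic time, $O(N^2)$, which is hard to accept.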
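Here is a minimal sketch of this brute-force approach. It assumes integer amounts stored in a `std::vector<long long>`, and the names `pick_bruteforce` and `simulate_bruteforce` are hypothetical, not part of the project. The random draw `r` is taken from [1, total], and person i is picked when the prefix sum through i-1 is below `r` and the prefix sum through i reaches it:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical brute-force pick: rebuild the prefix sums on every call, O(N).
// `r` must be a uniform integer in [1, total]; person i is chosen when
// money[0] + ... + money[i-1] < r <= money[0] + ... + money[i],
// so the chance of choosing i is money[i] / total.
std::size_t pick_bruteforce(const std::vector<long long>& money, long long r) {
    std::vector<long long> cdf(money.size());
    std::partial_sum(money.begin(), money.end(), cdf.begin());  // inclusive prefix sums
    assert(!cdf.empty() && r >= 1 && r <= cdf.back());
    return static_cast<std::size_t>(std::lower_bound(cdf.begin(), cdf.end(), r) - cdf.begin());
}

// Hypothetical simulation: n people start with 1 each; `steps` times, one more
// unit goes to a person chosen with probability proportional to their money.
// Every step costs O(N), so N steps cost O(N^2).
std::vector<long long> simulate_bruteforce(std::size_t n, std::size_t steps, unsigned seed) {
    std::vector<long long> money(n, 1);
    std::mt19937 gen(seed);
    long long total = static_cast<long long>(n);
    for (std::size_t s = 0; s < steps; ++s) {
        std::uniform_int_distribution<long long> uni(1, total);
        ++money[pick_bruteforce(money, uni(gen))];
        ++total;
    }
    return money;
}
```

For example, with money = {1, 2, 3}, the draws r = 2 and r = 3 both select index 1, so that person receives 2/6 of the draws.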
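For a fixed array, the prefix sums can of course be precomputed once in $O(N)$. Any range sum is then available in $O(1)$ (see LeetCode 303. Range Sum Query - Immutable), and a pick costs one $O(\log N)$ lower_bound. The STL class template discrete_distribution works this way, at least in libstdc++, which stores cumulative probabilities and searches them with std::lower_bound. The drawback is that the weights cannot be adjusted dynamically: every change requires recomputing all the prefix sums.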
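Here is a sketch of the static variant under the same assumptions; the `StaticPrefix` helper and `pick_with_std` are hypothetical names. The prefix sums are built once, range sums take O(1), and a pick is a single `std::lower_bound`. The second function only illustrates that `std::discrete_distribution` also receives its weights at construction time, so any weight change means building a new distribution:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: prefix sums computed once in O(N), immutable afterwards.
// pre[i] = a[0] + ... + a[i-1], with pre[0] = 0.
struct StaticPrefix {
    std::vector<long long> pre;

    explicit StaticPrefix(const std::vector<long long>& a) : pre(a.size() + 1, 0) {
        std::partial_sum(a.begin(), a.end(), pre.begin() + 1);
    }
    // Sum of a[left, right) in O(1), as in LeetCode 303.
    long long range_sum(std::size_t left, std::size_t right) const {
        return pre[right] - pre[left];
    }
    // For r in [1, total], return the i with pre[i] < r <= pre[i + 1]; O(log N).
    std::size_t pick(long long r) const {
        assert(r >= 1 && r <= pre.back());
        auto it = std::lower_bound(pre.begin() + 1, pre.end(), r);
        return static_cast<std::size_t>(it - (pre.begin() + 1));
    }
};

// std::discrete_distribution fixes its weights when it is constructed;
// after any weight changes, a new distribution has to be built in O(N).
std::size_t pick_with_std(const std::vector<long long>& weights, std::mt19937& gen) {
    std::discrete_distribution<std::size_t> dist(weights.begin(), weights.end());
    return dist(gen);
}
```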
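For prefix sums that support modification, the obvious candidates are the famous segment tree and the Binary Indexed Tree (see LeetCode 307. Range Sum Query - Mutable). A Binary Indexed Tree, for example, supports both point updates and prefix-sum queries in $O(\log N)$ time. Having nothing better to do, I figured I might as well practice on it, and it doubles as a C++ rehab exercise.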
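To make the motivation concrete, here is a hedged sketch of the simulation built on a minimal Fenwick tree. The `Fenwick` class, `pick` and `simulate` are simplified stand-ins written for illustration, not the project's BITree. `pick` binary-searches over positions and computes a prefix sum at each probe. That is the same $O(\log^2 N)$ pattern std::lower_bound follows when it runs over prefix sums:

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical minimal Fenwick tree over long long; tree_[0] is unused.
class Fenwick {
public:
    explicit Fenwick(std::size_t n) : tree_(n + 1, 0) {}
    std::size_t size() const { return tree_.size() - 1; }
    // a[index] += val, O(log N); i & (~i + 1) is the lowest set bit of i
    void add(std::size_t index, long long val) {
        for (std::size_t i = index + 1; i < tree_.size(); i += i & (~i + 1))
            tree_[i] += val;
    }
    // Sum of a[0, index), O(log N)
    long long prefix(std::size_t index) const {
        long long s = 0;
        for (std::size_t i = index; i > 0; i &= i - 1)
            s += tree_[i];
        return s;
    }
private:
    std::vector<long long> tree_;
};

// Smallest i with prefix(i + 1) >= r, for r in [1, total] and non-negative values.
// O(log N) probes of O(log N) each: O(log^2 N).
std::size_t pick(const Fenwick& fw, long long r) {
    assert(fw.size() > 0 && r >= 1 && r <= fw.prefix(fw.size()));
    std::size_t lo = 0, hi = fw.size() - 1;
    while (lo < hi) {
        std::size_t mid = lo + (hi - lo) / 2;
        if (fw.prefix(mid + 1) >= r) hi = mid; else lo = mid + 1;
    }
    return lo;
}

// Same process as the brute force, but each step is O(log^2 N) instead of O(N).
std::vector<long long> simulate(std::size_t n, std::size_t steps, unsigned seed) {
    Fenwick fw(n);
    std::vector<long long> money(n, 1);
    for (std::size_t i = 0; i < n; ++i) fw.add(i, 1);
    std::mt19937 gen(seed);
    long long total = static_cast<long long>(n);
    for (std::size_t s = 0; s < steps; ++s) {
        std::uniform_int_distribution<long long> uni(1, total);
        std::size_t who = pick(fw, uni(gen));
        ++money[who];
        fw.add(who, 1);
        ++total;
    }
    return money;
}
```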
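Since I was doing it anyway, I wanted it done properly and reasonably completely. The project implements a fairly complete Binary Indexed Tree as a class template. Following STL conventions, it also provides a random-access iterator over the prefix sums so that the tree works with STL algorithms. For example, to find the first element in the range [first, last) that is not less than val, you can call std::lower_bound(bt.begin(), bt.end(), val). This finds it in $O(\log^2 N)$ time: lower_bound performs $O(\log N)$ comparisons, and each dereference computes a prefix sum in $O(\log N)$.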
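If the $O(\log^2 N)$ bound matters, a Fenwick tree also supports a well-known $O(\log N)$ search that descends the implicit tree from the highest power of two (often called binary lifting). This is a swapped-in technique, not what the project does, and it assumes that all elements are non-negative. The sketch works on a raw 1-indexed array, and both function names are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: given a 1-indexed Fenwick array `tree` (tree[0] unused),
// return the smallest 0-based i with a[0] + ... + a[i] >= target, in O(log N).
// Requires all a[k] >= 0. Returns n (one past the end) if target exceeds the total.
std::size_t fenwick_lower_bound(const std::vector<long long>& tree, long long target) {
    assert(!tree.empty());
    std::size_t n = tree.size() - 1;
    std::size_t step = 1;
    while (step * 2 <= n) step *= 2;  // highest power of two <= n
    std::size_t pos = 0;              // invariant: a[0] + ... + a[pos-1] < original target
    for (; step > 0; step /= 2) {
        if (pos + step <= n && tree[pos + step] < target) {
            pos += step;              // skip the whole block (pos - step, pos]
            target -= tree[pos];
        }
    }
    return pos;
}

// Hypothetical helper: build the Fenwick array by repeated point updates, O(N log N).
std::vector<long long> build_fenwick(const std::vector<long long>& a) {
    std::vector<long long> tree(a.size() + 1, 0);
    for (std::size_t i = 0; i < a.size(); ++i)
        for (std::size_t j = i + 1; j < tree.size(); j += j & (~j + 1))
            tree[j] += a[i];
    return tree;
}
```

For 0 < val <= total, the result i corresponds to position i + 1 of the sum_iterator, because the iterator at position k yields the sum of [0, k).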
The project is mirrored on GitHub, with detailed bilingual Chinese and English comments and a README. Link:
https://github.com/LiuZJ2019/BinaryIndexedTree_TemplateClass_And_Iterator
The demo code below gives a rough overview of the project's features:
template<class T>
class BITree {
// Data members
private:
std::vector<T> m_tree; // Binary Indexed Tree
std::vector<T> m_arr;
...
// Public interface
public:
// fun1: random access
virtual T get(size_t index) const;
// fun2: increment an element
virtual void add(size_t index, T val);
// fun3: overwrite an element
virtual void update(size_t index, T val);
// fun4: sum[0:index]
virtual T sum(size_t index) const;
// fun5: sum[left:right]
virtual T sum(size_t left, size_t right) const;
// fun6: length of m_arr
virtual size_t size() const;
// fun7: resize
virtual void resize(size_t siz);
// fun8: iterator to the first prefix sum, usable with the STL
virtual iterator begin();
// fun9: past-the-end iterator, usable with the STL
virtual iterator end();
// fun10: get m_arr
virtual const std::vector<T> &get_arr() const;
// fun11: print
virtual void print(std::ostream &os) const;
// random-access iterator over the prefix sums
struct sum_iterator {
...
};
};
// Wrapper so the tree can be written to a stream with operator<<
template<class T>
std::ostream& operator<< (std::ostream& os, const BITree<T> &packet);
Full Code
The core code is bitree.h and bitree.cpp; bitree_test.h and main.cpp are for testing. bitree_test.h contains many examples showing how to use the code, and the code itself is thoroughly commented.
bitree.h
Too lazy to rework it, so here is the code, pasted directly:
/**
* @file bitree.h
* @brief Binary Indexed Tree
* @author 间宫羽咲sama
* @note Template implementation of Binary Indexed Tree
*/
#ifndef INC_BINARYINDEXEDTREE_BITREE_H
#define INC_BINARYINDEXEDTREE_BITREE_H
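#include <cstddef>   // std::ptrdiff_t
#include <iostream>
#include <iterator>  // std::random_access_iterator_tag
#include <utility>   // std::move
#include <vector>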
/**
* @brief Binary Indexed Tree
* @details The English notes are as follows:
* A 0/1 string in parentheses denotes a binary number, e.g. '(101100)'.
* Suppose 'L(x)' is the value of the lowest set bit of 'x'.
* For example, if 'x == 44 == (101100)', then 'L(x) == 4 == (100)'
* Suppose 'T(i)' is the sum of A over the interval ( i - L(i), i ]
* For example, if 'i == (101100)', then 'i - L(i) == (101000)', and 'T(i)' is $\sum_{k=i-L(i)+1}^{i} {A_k}$
* With this layout, both the update operation and the sum operation take $O(log N)$
*
* 1. Sum: O(log N)
* Notice that: (0, (101100)] = ( (0), (100000) ] + ( (100000), (101000) ] + ( (101000), (101100) ]
* The number of terms cannot exceed the number of bits, i.e. O(log N)
* 2. Update: O(log N)
* To modify index (101101), only the intervals containing it need to change;
* the intervals ( i - L(i), i ] for the following i contain (101101):
* (101101), (101110), (110000), (1000000), (10000000), (100000000) ...
* These indices cannot exceed the length of the array, so there are at most O(log N) of them
* 3. Get: O(1)
* A copy of the raw array is kept so that get is O(1)
*/
template<class T>
class BITree {
public:
// 1. This iterator can interact with algorithm 'lower_bound' in STL
// e.g. If you want to find the first element in the range [first, last) which does not compare less than 'val'
// You can use std::lower_bound(bt.begin(), bt.end(), val);
struct sum_iterator;  // same class-key as the definition below
using iterator = sum_iterator;
private:
// 2. Data member
std::vector<T> m_tree; // Binary Indexed Tree
std::vector<T> m_arr; // Let the get complexity become O(1), at the cost of doubling the space
protected:
// 3. Internal helper: adds 'val' to every tree node covering 'index' without touching m_arr; used by the constructors, add() and resize()
virtual void add_impl(size_t index, T val);
public:
// 4. Constructor
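BITree() = default;
virtual ~BITree() = default; // virtual, since the class has virtual member functions
BITree(const BITree &bt);
BITree(BITree &&bt);
// A user-declared move constructor suppresses the implicit assignment operators, so default them explicitly
BITree &operator=(const BITree &) = default;
BITree &operator=(BITree &&) = default;
BITree(const std::vector<T> &nums);
BITree(std::vector<T> &&nums);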
// 5. All public interfaces: get, add, update, sum, resize
/**
* @brief return A[index]
* @note plz ensure 0 <= index < m_arr.size()
*/
virtual T get(size_t index) const;
/**
* @brief Let A[index] increment by val
* @note plz ensure 0 <= index < m_arr.size()
*/
virtual void add(size_t index, T val);
/**
* @brief Let A[index] update by val, i.e. increment by (val - now_value)
* @note plz ensure 0 <= index < m_arr.size()
*/
virtual void update(size_t index, T val);
/**
* @brief Get the sum of the interval [0, index)
* @note plz ensure 0 <= index < m_arr.size() + 1
*/
virtual T sum(size_t index) const;
/**
* @brief Get the sum of the interval [left, right)
* if left > right, return the opposite of the sum of the interval [right, left)
* @note plz ensure 0 <= left, right < m_arr.size() + 1
*/
virtual T sum(size_t left, size_t right) const;
/**
* @brief return m_arr.size()
*/
virtual size_t size() const;
/**
* @brief resize, if siz < now_size, keep only [0, siz)
*/
virtual void resize(size_t siz);
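/**
* @brief return an iterator at position 0, which dereferences to sum(0) == 0
*/
virtual iterator begin();
/**
* @brief return the past-the-end iterator at position size();
* dereferencing it still works and yields sum(size()), the total sum
*/
virtual iterator end();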
// 6. Make it easier to use
/**
* @brief Get raw 'm_arr' for debug or other use
*/
virtual const std::vector<T> &get_arr() const;
// 7. For operator<< without the 'friend' keyword
virtual void print(std::ostream &os) const;
// 8. Random-access iterator over the prefix sums, e.g. for use with std::lower_bound.
// Position i dereferences to sum(i), the sum of [0, i). The member types are declared
// directly because std::iterator is deprecated since C++17. operator* returns by value,
// so strictly this is a proxy iterator, but it works with std::lower_bound in practice.
struct sum_iterator {
private:
BITree<T> *m_bt;
size_t m_idx;
using Self = sum_iterator;
public:
using iterator_category = std::random_access_iterator_tag;
using value_type = T;
using difference_type = std::ptrdiff_t; // iterator differences must be signed
using pointer = void;                   // there is no stored element to point to
using reference = T;                    // prefix sums are computed on the fly and returned by value
public:
size_t get_idx() const {return m_idx;}
public:
sum_iterator() : m_bt(nullptr), m_idx(0) {}
sum_iterator(BITree<T> *bt, size_t idx) : m_bt(bt), m_idx(idx) {}
// copy/move construction and assignment are the implicit member-wise ones
reference operator*() const {return m_bt->sum(m_idx);}
reference operator[](difference_type n) const {return m_bt->sum(m_idx + n);}
Self& operator++() {++ m_idx; return *this;}
Self operator++(int) {Self ret_val = *this; ++ (*this); return ret_val;}
Self& operator--() {-- m_idx; return *this;}
Self operator--(int) {Self ret_val = *this; -- (*this); return ret_val;}
Self& operator+=(difference_type n) {m_idx += n; return *this;}
Self& operator-=(difference_type n) {m_idx -= n; return *this;}
Self operator+(difference_type n) const {return Self(m_bt, m_idx + n);}
friend Self operator+(difference_type n, Self it) {return it + n;}
Self operator-(difference_type n) const {return Self(m_bt, m_idx - n);}
difference_type operator-(Self other) const {
return static_cast<difference_type>(m_idx) - static_cast<difference_type>(other.m_idx);
}
bool operator==(Self other) const {return m_idx == other.m_idx;}
bool operator!=(Self other) const {return !(*this == other);}
bool operator<(Self other) const {return m_idx < other.m_idx;}
bool operator<=(Self other) const {return m_idx <= other.m_idx;}
bool operator>(Self other) const {return m_idx > other.m_idx;}
bool operator>=(Self other) const {return m_idx >= other.m_idx;}
};
};
template<class T>
std::ostream& operator<< (std::ostream& os, const BITree<T> &packet);
// Template definitions must be visible wherever the template is instantiated, so the implementation file is included here instead of being compiled separately
#include "bitree.cpp"
#endif //INC_BINARYINDEXEDTREE_BITREE_H
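The lowbit arithmetic described in the header comment can be checked directly. The sketch below uses hypothetical helper names and 1-based positions, as in the comment. It reproduces the comment's examples: L(44) = 4, the update path starting at (101101) = 45, and the three intervals that tile (0, (101100)]:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: L(x), the value of the lowest set bit of x,
// written without applying unary minus to an unsigned type.
std::size_t lowbit(std::size_t x) { return x & (~x + 1); }

// Indices i whose interval (i - L(i), i] contains position `pos` (1-based),
// i.e. the nodes an update at `pos` must touch in a tree of size n.
std::vector<std::size_t> update_path(std::size_t pos, std::size_t n) {
    assert(pos >= 1);  // pos == 0 would loop forever, since L(0) == 0
    std::vector<std::size_t> path;
    for (std::size_t i = pos; i <= n; i += lowbit(i)) path.push_back(i);
    return path;
}

// Indices whose intervals tile (0, pos], i.e. the nodes a prefix-sum query reads.
std::vector<std::size_t> sum_path(std::size_t pos) {
    std::vector<std::size_t> path;
    for (std::size_t i = pos; i > 0; i -= lowbit(i)) path.push_back(i);
    return path;
}
```

With n = 200, update_path(45, 200) gives 45, 46, 48, 64, 128, which are (101101), (101110), (110000), (1000000), (10000000). sum_path(44) gives 44, 40, 32, matching (0, 32] + (32, 40] + (40, 44].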
bitree.cpp
Too lazy to rework it, so here is the code, pasted directly:
/**
* @file bitree.cpp
* @brief Binary Indexed Tree
* @author 间宫羽咲sama
* @note Template implementation of Binary Indexed Tree
*/
#ifndef INC_BINARYINDEXEDTREE_BITREE_CPP
#define INC_BINARYINDEXEDTREE_BITREE_CPP
#include "bitree.h"
template<class T>
BITree<T>::BITree(const BITree<T> &nums)
: m_tree(nums.m_tree)
, m_arr(nums.m_arr)
{
}
template<class T>
BITree<T>::BITree(BITree<T> &&nums)
: m_tree(std::move(nums.m_tree))
, m_arr(std::move(nums.m_arr))
{
}
template<class T>
BITree<T>::BITree(const std::vector<T> &nums)
: m_tree(nums.size() + 1)
, m_arr(nums)
{
for (size_t i = 0; i < m_arr.size(); ++ i)
add_impl(i, m_arr[i]);
}
template<class T>
BITree<T>::BITree(std::vector<T> &&nums)
: m_tree(nums.size() + 1)
, m_arr(std::move(nums))
{
for (size_t i = 0; i < m_arr.size(); ++ i)
add_impl(i, m_arr[i]);
}
template<class T> T
BITree<T>::get(size_t index) const {
return m_arr[index];
}
template<class T> void
BITree<T>::add_impl(size_t index, T val) {
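// index & (~index + 1) == L(index): the same value as index & (-index), without unary minus on an unsigned type
for (++ index; index < m_tree.size(); index += index & (~index + 1))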
m_tree[index] += val;
}
template<class T> void
BITree<T>::add(size_t index, T val) {
m_arr[index] += val;
add_impl(index, val);
}
template<class T> void
BITree<T>::update(size_t index, T val) {
add(index, val - get(index));
}
template<class T> T
BITree<T>::sum(size_t index) const {
T ans = 0;
for (; index > 0; index &= index - 1)
ans += m_tree[index];
return ans;
}
template<class T> T
BITree<T>::sum(size_t left, size_t right) const {
return sum(right) - sum(left);
}
template<class T> const std::vector<T> &
BITree<T>::get_arr() const {
return m_arr;
}
template<class T> size_t
BITree<T>::size() const {
return m_arr.size();
}
template<class T> void
BITree<T>::resize(size_t siz) {
if (siz == m_arr.size())
return;
m_arr.resize(siz);
// erase all
m_tree.resize(0);
m_tree.resize(siz + 1);
for (size_t i = 0; i < m_arr.size(); ++ i)
add_impl(i, m_arr[i]);
}
template<class T> typename BITree<T>::iterator
BITree<T>::begin() {
return BITree<T>::iterator(this, 0);
}
template<class T> typename BITree<T>::iterator
BITree<T>::end() {
return BITree<T>::iterator(this, size());
}
template<class T> void
BITree<T>::print(std::ostream &os) const {
os << "[";
// separator before every element except the first; also safe when m_arr is empty
for (size_t i = 0; i < m_arr.size(); ++ i)
os << (i > 0 ? "," : "") << m_arr[i];
os << "]";
}
template<class T> std::ostream &
operator<< (std::ostream &os, const BITree<T> &packet) {
packet.print(os);
return os;
}
#endif //INC_BINARYINDEXEDTREE_BITREE_CPP
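The constructors and resize() call add_impl once per element, so building the tree costs O(N log N). A well-known linear-time alternative fills each node and then pushes its completed total up to its parent i + L(i). The free functions below (hypothetical names) are a sketch on a 1-indexed array with the same layout as m_tree, where slot 0 is unused. They are not a drop-in patch for the class:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical O(N) build of a 1-indexed Fenwick array. When node i is reached,
// all of its children (indices j < i with j + L(j) == i) have already added
// their totals into it, so tree[i] is complete and can be pushed to its parent.
template <class T>
std::vector<T> build_fenwick_linear(const std::vector<T>& a) {
    std::vector<T> tree(a.size() + 1, T{});
    for (std::size_t i = 1; i <= a.size(); ++i) {
        tree[i] += a[i - 1];
        std::size_t parent = i + (i & (~i + 1));  // i + L(i)
        if (parent < tree.size()) tree[parent] += tree[i];
    }
    return tree;
}

// Sum of a[0, k) read back from the tree, O(log N).
template <class T>
T fenwick_prefix(const std::vector<T>& tree, std::size_t k) {
    assert(k < tree.size());
    T s{};
    for (; k > 0; k &= k - 1) s += tree[k];
    return s;
}
```

For a = {2, 7, 1, 8, 2, 8}, this produces the same array as six calls to an add_impl-style update: {0, 2, 9, 1, 18, 2, 10}.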
bitree_test.h
Too lazy to rework it, so here is the code, pasted directly:
#ifndef BINARYINDEXEDTREE_BITREE_TEST_H
#define BINARYINDEXEDTREE_BITREE_TEST_H
#include "bitree.h"
#include <algorithm>
void BITree_test1() {
std::cout << std::endl << "---------- BITree_test1 begin ----------" << std::endl;
std::cout << std::endl << "---------- Part-1: Construct ----------" << std::endl;
std::vector<double> vec1{2.71828, 3.14159, 1.14514, 1.19198, 0.57721};
size_t siz = vec1.size();
BITree<double> bt1(std::move(vec1));
// expect output: bt1 = [2.71828,3.14159,1.14514,1.19198,0.57721]
std::cout << "bt1 = " << bt1 << std::endl;
// expect output: vec1.size() = 0
std::cout << "vec1.size() = " << vec1.size() << std::endl;
BITree<double> bt2(bt1);
// expect output: bt2 = [2.71828,3.14159,1.14514,1.19198,0.57721]
std::cout << "bt2 = " << bt2 << std::endl;
BITree<double> bt3(std::move(bt2));
// expect output: bt2 = []
std::cout << "bt2 = " << bt2 << std::endl;
// expect output: bt3 = [2.71828,3.14159,1.14514,1.19198,0.57721]
std::cout << "bt3 = " << bt3 << std::endl;
std::cout << std::endl << "---------- Part-2: Sum ----------" << std::endl;
// expect output: bt1.sum(0) = 0
// bt1.sum(1) = 2.71828
// bt1.sum(2) = 5.85987
// bt1.sum(3) = 7.00501
// bt1.sum(4) = 8.19699
// bt1.sum(5) = 8.7742
for (size_t i = 0; i <= siz; ++ i)
std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;
std::cout << std::endl << "---------- Part-3: Add ----------" << std::endl;
bt1.add(2, 1);
// expect output: bt1 = [2.71828,3.14159,2.14514,1.19198,0.57721]
std::cout << "bt1 = " << bt1 << std::endl;
// expect output: bt1.sum(0) = 0
// bt1.sum(1) = 2.71828
// bt1.sum(2) = 5.85987
// bt1.sum(3) = 8.00501
// bt1.sum(4) = 9.19699
// bt1.sum(5) = 9.7742
for (size_t i = 0; i <= siz; ++ i)
std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;
std::cout << std::endl << "---------- Part-4: ReSize ----------" << std::endl;
bt1.resize(siz - 1);
// expect output: bt1 = [2.71828,3.14159,2.14514,1.19198]
std::cout << "bt1 = " << bt1 << std::endl;
// expect output: bt1.sum(0) = 0
// bt1.sum(1) = 2.71828
// bt1.sum(2) = 5.85987
// bt1.sum(3) = 8.00501
// bt1.sum(4) = 9.19699
for (size_t i = 0; i <= siz - 1; ++ i)
std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;
bt3.resize(siz + 1);
// expect output: bt3 = [2.71828,3.14159,1.14514,1.19198,0.57721,0]
std::cout << "bt3 = " << bt3 << std::endl;
// expect output: bt3.sum(0) = 0
// bt3.sum(1) = 2.71828
// bt3.sum(2) = 5.85987
// bt3.sum(3) = 7.00501
// bt3.sum(4) = 8.19699
// bt3.sum(5) = 8.7742
// bt3.sum(6) = 8.7742
for (size_t i = 0; i <= siz + 1; ++ i)
std::cout << "bt3.sum(" << i << ") = " << bt3.sum(i) << std::endl;
std::cout << std::endl << "---------- Part-5: Iterator ----------" << std::endl;
BITree<double>::iterator it3 = std::lower_bound(bt1.begin(), bt1.end(), 5);
// expect output: std::lower_bound(bt1.begin(), bt1.end(), 5) = 5.85987
std::cout << "std::lower_bound(bt1.begin(), bt1.end(), 5) = " << *it3 << std::endl;
++ it3;
// expect output: *++it = 8.00501
std::cout << "*++it = " << *it3 << std::endl;
it3 = std::lower_bound(bt1.begin(), bt1.end(), 5.85988);
// expect output: std::lower_bound(bt1.begin(), bt1.end(), 5.85988) = 8.00501
std::cout << "std::lower_bound(bt1.begin(), bt1.end(), 5.85988) = " << *it3 << std::endl;
// expect output: bt1.sum(0) = 0
// bt1.sum(1) = 2.71828
// bt1.sum(2) = 5.85987
// bt1.sum(3) = 8.00501
// bt1.sum(4) = 9.19699
// '<=' is intentional: the loop also dereferences end(), which yields sum(size())
for (auto it = bt1.begin(); it <= bt1.end(); ++ it)
std::cout << "bt1.sum(" << it.get_idx() << ") = " << *it << std::endl;
// expect output: bt1.sum... = 0
// bt1.sum... = 2.71828
// bt1.sum... = 5.85987
// bt1.sum... = 8.00501
for (auto prefix_sum : bt1)
std::cout << "bt1.sum... = " << prefix_sum << std::endl;
}
#endif //BINARYINDEXEDTREE_BITREE_TEST_H
main.cpp
Too lazy to rework it, so here is the code, pasted directly:
#include <iostream>
#include "bitree.h"
#include "bitree_test.h"
int main() {
BITree_test1();
return 0;
}
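Summary
Overall, the code itself is secondary. The main point was to use the opportunity to practice writing C++ properly: how to write copy, move and default constructors, and how to write an iterator that follows STL conventions.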
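Along the way I also stepped into some well-known pitfalls. For example, a class template cannot simply be split into a .h file and a separately compiled .cpp file, because the compiler needs the complete definitions wherever the template is instantiated. Here that is solved by including the corresponding .cpp file at the end of the .h file. The other common option is to explicitly instantiate every needed type in the .cpp file. Writing the iterator was also a headache, but once you have written one it makes much more sense. All in all, the project itself is nothing special; the real value was the hands-on practice with these things.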
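Here is a sketch of the explicit-instantiation alternative. The member definitions stay in a separately compiled source file, which explicitly instantiates the template for every type the program uses. The `SumBox` class and the file names in the comments are hypothetical. Everything is shown in one file so that it compiles on its own, and comments mark which part would go where:

```cpp
#include <cassert>
#include <vector>

// --- would live in a header, e.g. "sum_box.h" (hypothetical): declarations only ---
template <class T>
class SumBox {
public:
    void push(T v);
    T total() const;
private:
    std::vector<T> items_;
};

// --- would live in a separately compiled "sum_box.cpp" (hypothetical): definitions ---
template <class T>
void SumBox<T>::push(T v) { items_.push_back(v); }

template <class T>
T SumBox<T>::total() const {
    T s{};
    for (const T& v : items_) s += v;
    return s;
}

// Explicit instantiation definitions: this translation unit emits the code for
// exactly these types, so other files that include only the header can link
// against SumBox<int> and SumBox<double>.
template class SumBox<int>;
template class SumBox<double>;
```

The trade-off is that only the listed types work. Using SumBox<float> from another file would fail to link until that type was added to the list.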