您的位置:首页 > 其它

线段树 原理及模板

2018-02-21 21:58 405 查看
代码模板:

https://github.com/rsy56640/rsy_little_lib/tree/master/library_for_algorithm/SegmentTree

线段树原理:

对于一个给定长度的区间 $l$,我们递归地定义:

区间 $l$ 的线段树为:

其根节点表示区间 $l$;将该区间近似等分为两段 $l_1, l_2$,左子树为区间 $l_1$ 的线段树,右子树为区间 $l_2$ 的线段树。

并且提供一个二元代数运算 $\circ$,值类型关于该运算构成一个幺半群(即运算满足结合律,且存在幺元)。

线段树只提供两个方法:

1. 区间查询 $O(\log n)$

2. 单点修改 $O(\log n)$

SegmentTreeType.h

#pragma once
#ifndef SEGMENTTREE_TYPE_H
#define SEGMENTTREE_TYPE_H
#include <functional>
#include <memory>
#include <type_traits>

template<class _Ty>
class SegmentTreeNode;

// Common type aliases shared by the segment-tree classes, which all inherit
// from this struct:
//   value_type          - _Ty with any reference stripped.
//   SegmentTreeNode_ptr - shared owning pointer to a tree node.
//   Func                - binary operation that combines two values into one.
// The std:: (C++11) facilities are used instead of std::tr1, which standard
// libraries are not required to provide.
template<class _Ty>
struct SegmentTreeType
{
    using value_type = typename std::remove_reference<_Ty>::type;
    using SegmentTreeNode_ptr = std::shared_ptr<SegmentTreeNode<_Ty> >;
    using Func = std::function<_Ty(const _Ty&, const _Ty&)>;
};

#endif // !SEGMENTTREE_TYPE_H


SegmentTreeNode.h

#pragma once
#ifndef SEGMENTTREE_NODE_H
#define SEGMENTTREE_NODE_H
#include "SegmentTreeType.h"
#include <utility>

// One node of a segment tree. It covers the closed index range [start, end]
// of the underlying sequence and stores the value aggregated over that range.
// Leaves have start == end. Children are reached through left()/right() and
// are null until assigned.
template<class _Ty> class SegmentTreeNode :public SegmentTreeType<_Ty>
{

public:

    // Names declared in a dependent base class are not found by unqualified
    // lookup in standard C++ (two-phase lookup), so re-declare the alias here.
    using typename SegmentTreeType<_Ty>::SegmentTreeNode_ptr;

    SegmentTreeNode(int start, int end, _Ty val)
        :_start(start), _end(end), _val(std::move(val)), _left(nullptr), _right(nullptr) {}

    // First index covered by this node.
    int start() const noexcept
    {
        return _start;
    }

    // Last index (inclusive) covered by this node.
    int end() const noexcept
    {
        return _end;
    }

    // Returns a copy of the stored value. Not noexcept: copying _Ty may throw,
    // and a throwing copy inside a noexcept function would call std::terminate.
    _Ty val() const
    {
        return _val;
    }

    // Returns a const reference to the stored value (no copy).
    const _Ty& value() const noexcept
    {
        return _val;
    }

    // Replaces the stored value. Taking the argument by value accepts both
    // lvalues (copied) and rvalues (moved).
    void setValue(_Ty value)
    {
        _val = std::move(value);
    }

    // Mutable access to the left child pointer (used to attach subtrees).
    SegmentTreeNode_ptr& left() noexcept
    {
        return _left;
    }

    const SegmentTreeNode_ptr& left() const noexcept
    {
        return _left;
    }

    // Mutable access to the right child pointer (used to attach subtrees).
    SegmentTreeNode_ptr& right() noexcept
    {
        return _right;
    }

    const SegmentTreeNode_ptr& right() const noexcept
    {
        return _right;
    }

    SegmentTreeNode(const SegmentTreeNode&) = delete;

    SegmentTreeNode& operator=(const SegmentTreeNode&) = delete;

    SegmentTreeNode(SegmentTreeNode&&) = delete;

    SegmentTreeNode& operator=(SegmentTreeNode&&) = delete;

    ~SegmentTreeNode() = default;

private:

    int _start, _end;
    _Ty _val;
    SegmentTreeNode_ptr _left, _right;

};

#endif // !SEGMENTTREE_NODE_H


SegmentTreeException.h

#pragma once
#ifndef SEGMENTTREE_EXCEPTION_H
#define SEGMENTTREE_EXCEPTION_H
#include <exception>
#include <ostream>
#include <string>
#include <utility>

// Exception thrown by the segment-tree classes for an empty input sequence,
// an invalid query range, or an invalid index. _Ty is the tree's value type;
// it is unused in the body and only makes the exception type distinct per
// tree instantiation.
template<class _Ty> class SegmentTreeException :public std::exception
{

public:
    SegmentTreeException(std::string msg)
        :_msg(std::move(msg)) {}

    // Generic description. The specific reason is available through
    // message() or by streaming the exception with operator<<.
    const char* what() const noexcept override
    {
        return "SegmentTree Exception";
    }

    // The specific reason given at construction.
    const std::string& message() const noexcept
    {
        return _msg;
    }

private:

    std::string _msg;

    // Writes the specific reason (not what()) to os.
    friend
    std::ostream& operator<<(std::ostream& os, const SegmentTreeException<_Ty>& e)
    {
        os << e._msg;
        return os;
    }

};

#endif // !SEGMENTTREE_EXCEPTION_H


SegmentTreeImpl.h

#pragma once
#ifndef SEGMENTTREE_IMPL_H
#define SEGMENTTREE_IMPL_H
#include "SegmentTreeNode.h"
#include "SegmentTreeException.h"
#include <memory>
#include <utility>
#include <vector>
// NOTE(review): header-scope using-declaration kept because other files refer
// to unqualified `vector`; prefer std::vector in new code.
using std::vector;

// Segment tree over a non-empty sequence. Each node stores _Func folded over
// its closed index range [start, end]. _Func must be associative and
// _Identity_Element must be its identity (i.e. _Ty forms a monoid under
// _Func); otherwise query results are not meaningful.
template<class _Ty> class SegmentTreeImpl :public SegmentTreeType<_Ty>
{

public:

    // Aliases from the dependent base are invisible to unqualified lookup in
    // standard C++ (two-phase lookup), so bring them into scope explicitly.
    using typename SegmentTreeType<_Ty>::SegmentTreeNode_ptr;
    using typename SegmentTreeType<_Ty>::Func;

    // Builds the tree from Vec: O(n) time, O(n) space.
    // Throws SegmentTreeException if Vec is empty.
    SegmentTreeImpl(const vector<_Ty>& Vec, Func func, _Ty Identity_Element)
        : _root(nullptr), _Func(std::move(func)),
          _Identity_Element(std::move(Identity_Element)), _checked(false)
    {
        if (Vec.empty())
            throw SegmentTreeException<_Ty>("The Segment is empty!!");

        _root = build(0, static_cast<int>(Vec.size()) - 1, Vec);
        _checked = true;
    }

    // Returns _Func folded over the indices of [start, end] that lie inside
    // the tree; parts of the range outside [0, n-1] contribute the identity.
    // O(log n). Throws SegmentTreeException if start > end.
    _Ty query(int start, int end) const
    {
        if (!_checked)
            throw SegmentTreeException<_Ty>("The Segment is empty!!");

        if (start > end)
            throw SegmentTreeException<_Ty>("The querying range is invalid!!");

        return doQuery(_root, start, end);
    }

    // Replaces the element at index with value and recomputes its ancestors.
    // O(log n). Accepts lvalues (copied) and rvalues (moved).
    // Throws SegmentTreeException if index is outside [0, n-1].
    void modify(int index, _Ty value)
    {
        if (!_checked)
            throw SegmentTreeException<_Ty>("The Segment is empty!!");

        if (index < _root->start() || index > _root->end())
            throw SegmentTreeException<_Ty>("The Index is invalid!!");

        doModify(_root, index, std::move(value));
    }

    SegmentTreeImpl(const SegmentTreeImpl&) = delete;

    SegmentTreeImpl& operator=(const SegmentTreeImpl&) = delete;

    SegmentTreeImpl(SegmentTreeImpl&&) = delete;

    SegmentTreeImpl& operator=(SegmentTreeImpl&&) = delete;

    ~SegmentTreeImpl() = default;

protected:

    SegmentTreeNode_ptr _root;

    // Associative binary operation on _Ty with identity _Identity_Element.
    Func _Func;

    // Identity element of _Func.
    const _Ty _Identity_Element;

    // True once the tree has been built; the constructor throws before
    // setting it, so a fully constructed object always has it set.
    bool _checked;

private:

    // Recursively builds the subtree covering [start, end] and returns its root.
    SegmentTreeNode_ptr build(int start, int end, const vector<_Ty>& Vec)
    {
        if (start == end)
            return std::make_shared<SegmentTreeNode<_Ty> >(start, end, Vec[start]);

        // Same split point as (start + end) / 2 for non-negative indices,
        // but cannot overflow.
        const int mid = start + (end - start) / 2;

        // Internal node: created with the identity, then set from its children.
        SegmentTreeNode_ptr node =
            std::make_shared<SegmentTreeNode<_Ty> >(start, end, _Identity_Element);
        node->left() = build(start, mid, Vec);
        node->right() = build(mid + 1, end, Vec);
        node->setValue(_Func(node->left()->value(), node->right()->value()));

        return node;
    }

    // Folds _Func over the intersection of [start, end] with node's range.
    _Ty doQuery(const SegmentTreeNode_ptr& node, int start, int end) const
    {
        // Disjoint ranges contribute the identity.
        if (start > node->end() || end < node->start())
            return _Identity_Element;

        // Node range fully inside the query: use the precomputed value.
        if (start <= node->start() && node->end() <= end)
            return node->val();

        // Partial overlap: combine both children.
        return _Func(doQuery(node->left(), start, end),
                     doQuery(node->right(), start, end));
    }

    // Sets the leaf for index (which the caller guarantees lies inside node's
    // range) and recomputes every node on the path back up.
    void doModify(const SegmentTreeNode_ptr& node, int index, _Ty&& value)
    {
        if (node->start() == node->end())
        {
            node->setValue(std::move(value));
            return;
        }

        const int mid = node->start() + (node->end() - node->start()) / 2;
        doModify(index <= mid ? node->left() : node->right(), index, std::move(value));
        node->setValue(_Func(node->left()->value(), node->right()->value()));
    }

};

#endif // !SEGMENTTREE_IMPL_H


SegmentTree.h

#pragma once
#ifndef SEGMENTTREE_H
#define SEGMENTTREE_H
#include "SegmentTreeImpl.h"
#include <memory>
#include <utility>
#include <vector>

// Public segment-tree facade (pimpl over SegmentTreeImpl).
// NOTE: the implementation is held through a shared_ptr, so copies of a
// SegmentTree share one underlying tree; modify() through any copy is
// visible through all of them.
template<class _Ty> class SegmentTree :public SegmentTreeType<_Ty>
{

    using PImpl = std::shared_ptr<SegmentTreeImpl<_Ty> >;

public:

    // Re-declared because names from a dependent base are not found by
    // unqualified lookup in standard C++.
    using typename SegmentTreeType<_Ty>::Func;

    // Builds the tree over Vec with the associative operation func whose
    // identity is Identity_Element. Throws SegmentTreeException if Vec is empty.
    SegmentTree(const std::vector<_Ty>& Vec, Func func, _Ty Identity_Element)
        :_pImpl(std::make_shared<SegmentTreeImpl<_Ty> >(
            Vec, std::move(func), std::move(Identity_Element))) {}

    // Range query over [start, end], O(log n).
    // Throws SegmentTreeException if start > end.
    _Ty query(int start, int end) const
    {
        return _pImpl->query(start, end);
    }

    // Point update at index, O(log n). Accepts lvalues and rvalues.
    // Throws SegmentTreeException if index is out of range.
    void modify(int index, _Ty value)
    {
        _pImpl->modify(index, std::move(value));
    }

private:

    PImpl _pImpl;

};

#endif // !SEGMENTTREE_H


main.cpp

#include "SegmentTree.h"
#include <iostream>

using namespace std;

// Combining operation for the demo tree: the larger of the two arguments
// (the second one when they are equal).
int foo(int a, int b)
{
    if (b < a)
        return a;
    return b;
}

int main()
{

vector<int> v = { 1,2,7,8,5 };

try
{
SegmentTree<int> h(v, foo, 0);

int a = h.query(0, 2);

h.modify(0, 4);

int b = h.query(0, 1);

h.modify(2, 11);

int c = h.query(2, 3);

}
catch (SegmentTreeException<int>& e)
{
cout << e << endl;
}

system("pause");
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: