您的位置:首页 > 其它

RSA算法的一种实现方式!

2013-05-29 13:09 218 查看
在开始解释这个算法之前,我先描述一下这里面大数的存储方式!存储结构如下图:



在这里,一个大数总共占maxLength(一般取512)个单元,由于每一个单元都是unsigned int型的,因此每一个单元都有32位(bit)。这里的大数大家可以看做是一个超大型的int型数据,总共有512*32位(一般的int只有32位),低位单元在前(data[0]存放最低的32位)。最高位为符号位,1代表负数,0代表正数,负数以二进制补码形式存储。

还有,我们很难将这个数用具体的十进制表示出来,因为2^(512*32)是一个很大的数,但是这样的一个结构十分有利于我们做具体的位操作!

下面我们看一看头文件:


#pragma once
#include <cctype>
#include <cstring>
#include <string>
#include <algorithm>
#include <assert.h>

using namespace std;

/// Fixed-capacity signed big integer (a port of a C# BigInteger class).
///
/// Representation: `data` owns an array of `maxLength` 32-bit words stored
/// least-significant word first (data[0] holds the lowest 32 bits). Negative
/// values are kept in two's complement over the whole array, so the top bit
/// of data[maxLength - 1] is the sign bit. `dataLength` is the number of
/// low-order words in use (>= 1 for a well-formed value).
///
/// Overflow and invalid input are reported with assert(false) in the
/// implementation, i.e. they go undetected in NDEBUG builds.
class BigInteger
{
	typedef unsigned char byte;
public:
	BigInteger(void);                                   // the value 0
	BigInteger(__int64 value);                          // from a signed 64-bit value
	BigInteger(unsigned __int64 value);                 // from an unsigned 64-bit value
	BigInteger(const BigInteger &bi);                   // deep copy
	BigInteger(string value, int radix);                // parse digits 0-9/A-Z in `radix`, optional leading '-'
	BigInteger(byte inData[], int inLen);               // big-endian byte array (inData[0] most significant)
	BigInteger(unsigned int inData[], int inLen);       // big-endian word array (inData[0] most significant)
	BigInteger operator =(const BigInteger &bi2);
	BigInteger operator +(const BigInteger &bi2);
	BigInteger operator -();                            // two's-complement negation
	BigInteger modPow(BigInteger exp, BigInteger n);    // (*this)^exp mod n
	int bitCount();                                     // number of significant bits; trims dataLength
	BigInteger BarrettReduction(BigInteger x, BigInteger n, BigInteger constant); // x mod n, constant = b^(2k)/n
	bool operator >=(BigInteger bi2)
	{
		return ((*this) == bi2 || (*this) > bi2);
	}
	bool operator >(BigInteger bi2);
	bool operator ==(BigInteger bi2);                   // compares dataLength and the used words
	BigInteger operator %(BigInteger bi2);              // remainder takes the dividend's sign
	// Division helpers: both expect non-negative operands.
	void singleByteDivide(BigInteger &bi1, BigInteger &bi2,
		BigInteger &outQuotient, BigInteger &outRemainder);
	void multiByteDivide(BigInteger &bi1, BigInteger &bi2,
		BigInteger &outQuotient, BigInteger &outRemainder);
	// In-place shifts of a raw word buffer; both return the new used length.
	int shiftRight(unsigned int buffer[], int bufLen, int shiftVal);
	BigInteger operator <<(int shiftVal);
	int shiftLeft(unsigned int buffer[], int bufLen, int shiftVal);
	bool operator <(BigInteger bi2);
	BigInteger operator +=(BigInteger bi2);
	BigInteger operator /(BigInteger bi2);              // truncating signed division
	BigInteger operator -=(BigInteger bi2);
	BigInteger operator -(BigInteger bi2);
	string DecToHex(unsigned int value, string format); // one word as upper-case hex ("X8" pads to 8 digits)
	string ToHexString();                               // used words as hex, most significant first
public:
	~BigInteger(void);
public:
	BigInteger operator *(BigInteger bi2);
private:

public:
	static const int primesBelow2000[];
		// all primes below 2000 (not referenced by the code shown here)
	int dataLength;
		// number of 32-bit words of `data` actually in use
private:
	static const int maxLength;
		// capacity of every BigInteger, in 32-bit words;
		// change this to suit the required level of precision.
	unsigned int *data;
		// owned buffer of maxLength words, least significant word first
};


下面是重点,关于大数类的实现:

// stdafx.h : include file for standard system include files,
// or project specific include files that are used frequently, but
// are changed infrequently
//

#pragma once

#define WIN32_LEAN_AND_MEAN		// Exclude rarely-used stuff from Windows headers
#include <stdio.h>
#include <tchar.h>

// TODO: reference additional headers your program requires here


#include "StdAfx.h"
#include "BigInteger.h"

const int BigInteger::maxLength = 512;   // capacity of every BigInteger in 32-bit words (512 * 32 = 16384 bits)

const int BigInteger::primesBelow2000[] = {   // every prime below 2000, in ascending order
	2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
	101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199,
	211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293,
	307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
	401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499,
	503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599,
	601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691,
	701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797,
	809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887,
	907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997,
	1009, 1013, 1019, 1021, 1031, 1033, 1039, 1049, 1051, 1061, 1063, 1069, 1087, 1091, 1093, 1097,
	1103, 1109, 1117, 1123, 1129, 1151, 1153, 1163, 1171, 1181, 1187, 1193,
	1201, 1213, 1217, 1223, 1229, 1231, 1237, 1249, 1259, 1277, 1279, 1283, 1289, 1291, 1297,
	1301, 1303, 1307, 1319, 1321, 1327, 1361, 1367, 1373, 1381, 1399,
	1409, 1423, 1427, 1429, 1433, 1439, 1447, 1451, 1453, 1459, 1471, 1481, 1483, 1487, 1489, 1493, 1499,
	1511, 1523, 1531, 1543, 1549, 1553, 1559, 1567, 1571, 1579, 1583, 1597,
	1601, 1607, 1609, 1613, 1619, 1621, 1627, 1637, 1657, 1663, 1667, 1669, 1693, 1697, 1699,
	1709, 1721, 1723, 1733, 1741, 1747, 1753, 1759, 1777, 1783, 1787, 1789,
	1801, 1811, 1823, 1831, 1847, 1861, 1867, 1871, 1873, 1877, 1879, 1889,
	1901, 1907, 1913, 1931, 1933, 1949, 1951, 1973, 1979, 1987, 1993, 1997, 1999 };

/// Constructs the value 0: a zero-filled buffer of maxLength words, one word in use.
BigInteger::BigInteger(void)
: dataLength(1), data(new unsigned int[maxLength])
{
	memset(data, 0, sizeof(unsigned int) * maxLength);
}

/// Builds a BigInteger from a signed 64-bit value.
///
/// The value is split into 32-bit words, low word first. For a negative input
/// an arithmetic right shift (as on MSVC) settles at -1 and never reaches 0,
/// so the loop keeps storing 0xFFFFFFFF words until the whole buffer is
/// filled. That is what sign-extends the two's-complement value to maxLength
/// words.
BigInteger::BigInteger(__int64 value)
{
	data = new unsigned int[maxLength];
	memset(data, 0, sizeof(unsigned int) * maxLength);

	const __int64 original = value;

	for (dataLength = 0; value != 0 && dataLength < maxLength; ++dataLength)
	{
		data[dataLength] = (unsigned int)(value & 0xFFFFFFFF);   // low 32 bits
		value >>= 32;                                           // remaining high bits
	}

	if (original > 0)
	{
		// A positive value must be fully consumed and must not reach the sign bit.
		const bool overflow = value != 0 || (data[maxLength - 1] & 0x80000000) != 0;
		if (overflow)
			assert(false);
	}
	else if (original < 0)
	{
		// A negative value must end fully sign-extended, with the sign bit set.
		const bool underflow = value != -1 || (data[dataLength - 1] & 0x80000000) == 0;
		if (underflow)
			assert(false);
	}

	// Zero still occupies one word.
	if (dataLength == 0)
		dataLength = 1;
}

/// Builds a BigInteger from an unsigned 64-bit value (at most two words).
BigInteger::BigInteger(unsigned __int64 value)
{
	data = new unsigned int[maxLength];
	memset(data, 0, sizeof(unsigned int) * maxLength);

	dataLength = 0;
	for (; value != 0 && dataLength < maxLength; ++dataLength, value >>= 32)
		data[dataLength] = (unsigned int)(value & 0xFFFFFFFF);

	// Bits left over, or a value that reached the sign bit, is an overflow.
	if (value != 0 || (data[maxLength - 1] & 0x80000000) != 0)
		assert(false);

	// Zero still occupies one word.
	if (dataLength == 0)
		dataLength = 1;
}

/// Deep copy: allocates a fresh zero-filled buffer and copies only the words
/// in use (words past bi.dataLength are zero in the copy).
BigInteger::BigInteger(const BigInteger &bi)
: dataLength(bi.dataLength), data(new unsigned int[maxLength])
{
	memset(data, 0, sizeof(unsigned int) * maxLength);
	memcpy(data, bi.data, sizeof(unsigned int) * dataLength);
}

/// Releases the word buffer. delete[] on a null pointer is a no-op, so no
/// explicit check is needed.
BigInteger::~BigInteger(void)
{
	delete[] data;
}

/// Parses `value` as an integer written in base `radix`.
///
/// Digits are '0'-'9' followed by 'A'-'Z' (letters are case-insensitive), so
/// radix may be at most 36. An optional leading '-' makes the result
/// negative. A character that is not a valid digit for `radix` triggers
/// assert(false) and is skipped in NDEBUG builds. An empty string yields 0.
BigInteger::BigInteger(string value, int radix)
{
	BigInteger multiplier((__int64)1);   // radix^(position of the current digit from the right)
	BigInteger result;

	// Upper-case the text so 'a'-'z' are accepted like 'A'-'Z'. The argument is
	// cast to unsigned char because toupper has undefined behavior for negative
	// char values (e.g. non-ASCII bytes). A plain loop also avoids passing the
	// overloaded name `toupper` to std::transform.
	for (string::size_type k = 0; k < value.size(); k++)
		value[k] = (char)toupper((unsigned char)value[k]);

	const bool negative = !value.empty() && value[0] == '-';
	const int limit = negative ? 1 : 0;   // index of the most significant digit

	// Accumulate digit * radix^position, from the rightmost digit leftwards.
	for (int i = (int)value.size() - 1; i >= limit; i--)
	{
		int posVal = (int)value[i];

		if (posVal >= '0' && posVal <= '9')
			posVal -= '0';
		else if (posVal >= 'A' && posVal <= 'Z')
			posVal = (posVal - 'A') + 10;
		else
			posVal = 9999999;       // not a digit: force the radix check below to fail

		if (posVal >= radix)        // digit not valid in this radix
		{
			assert(false);
		}
		else
		{
			// A negative number is built by summing negative digit terms.
			if (negative)
				posVal = -posVal;

			result = result + (multiplier * BigInteger((__int64)posVal));

			if ((i - 1) >= limit)
				multiplier = multiplier * BigInteger((__int64)radix);
		}
	}

	// The accumulated sign must match the requested one ("-0" is simply 0).
	const bool resultIsZero = (result.dataLength == 1 && result.data[0] == 0);
	const bool resultNegative = (result.data[maxLength - 1] & 0x80000000) != 0;
	if (negative && !resultIsZero && !resultNegative)
		assert(false);
	if (!negative && resultNegative)
		assert(false);

	data = new unsigned int[maxLength];
	memset(data, 0, maxLength * sizeof(unsigned int));
	for (int i = 0; i < result.dataLength; i++)
		data[i] = result.data[i];

	dataLength = result.dataLength;
}

/// Builds a non-negative BigInteger from a big-endian byte array:
/// inData[0] is the most significant byte, inData[inLen - 1] the least.
/// Every 4 bytes form one 32-bit word; the leading 1-3 bytes (if inLen is
/// not a multiple of 4) form the top word. An empty array yields 0.
/// An array longer than maxLength words triggers assert(false); in NDEBUG
/// builds only its lowest maxLength words are kept.
BigInteger::BigInteger(byte inData[], int inLen)
{
	if (inLen < 0)
		inLen = 0;

	dataLength = inLen >> 2;          // whole 4-byte words
	int leftOver = inLen & 0x3;       // bytes left for a partial top word
	if (leftOver != 0)
		dataLength++;

	if (dataLength > maxLength)
		assert(false);

	data = new unsigned int[maxLength];
	memset(data, 0, maxLength * sizeof(unsigned int));

	// Full words, taken from the least significant end of the array. Each byte
	// is widened to unsigned int before shifting: a byte >= 0x80 promoted to
	// int and shifted left by 24 would overflow a signed int (undefined).
	// The j < maxLength bound keeps oversized input from writing past the buffer.
	int j = 0;
	for (int i = inLen - 1; i >= 3 && j < maxLength; i -= 4, j++)
	{
		data[j] = ((unsigned int)inData[i - 3] << 24) |
			((unsigned int)inData[i - 2] << 16) |
			((unsigned int)inData[i - 1] << 8) |
			(unsigned int)inData[i];
	}

	// The leading inData[0 .. leftOver-1] bytes form the top word.
	if (leftOver != 0 && dataLength <= maxLength)
	{
		unsigned int top = 0;
		for (int b = 0; b < leftOver; b++)
			top = (top << 8) | (unsigned int)inData[b];
		data[dataLength - 1] = top;
	}

	if (dataLength > maxLength)
		dataLength = maxLength;
	if (dataLength == 0)
		dataLength = 1;   // an empty array is the value 0, which still uses one word

	// Drop leading zero words so dataLength reflects the actual value.
	while (dataLength > 1 && data[dataLength - 1] == 0)
		dataLength--;
}

/// Builds a BigInteger from a big-endian word array: inData[0] is the most
/// significant 32-bit word, inData[inLen - 1] the least. An empty array
/// yields 0. An array longer than maxLength words triggers assert(false);
/// in NDEBUG builds only its lowest maxLength words are kept.
BigInteger::BigInteger(unsigned int inData[], int inLen)
{
	const int count = (inLen < 0) ? 0 : inLen;

	dataLength = count;
	if (dataLength > maxLength)
	{
		assert(false);
		dataLength = maxLength;   // never copy more words than the buffer holds
	}

	data = new unsigned int[maxLength];
	memset(data, 0, maxLength * sizeof(unsigned int));   // clear every word

	// Reverse into least-significant-word-first order.
	for (int j = 0; j < dataLength; j++)
		data[j] = inData[count - 1 - j];

	if (dataLength == 0)
		dataLength = 1;   // an empty array is the value 0

	while (dataLength > 1 && data[dataLength - 1] == 0)
		dataLength--;
}

/// Signed multiplication: returns (*this) * bi2.
///
/// The absolute values are multiplied with schoolbook word-by-word
/// multiplication and the sign is applied afterwards. An overflow triggers
/// assert(false). Overflow means a product that needs more than maxLength
/// words, or one that sets the sign bit (other than the most negative value
/// when the operand signs differ). Product words outside the buffer are never
/// written, so an overflow cannot corrupt memory even in NDEBUG builds.
BigInteger BigInteger::operator *(BigInteger bi2)
{
	BigInteger bi1(*this);
	int lastPos = maxLength - 1;
	bool bi1Neg = false, bi2Neg = false;

	// Take the absolute value of the inputs.
	if ((this->data[lastPos] & 0x80000000) != 0)     // bi1 negative
	{
		bi1Neg = true;
		bi1 = -bi1;
	}
	if ((bi2.data[lastPos] & 0x80000000) != 0)     // bi2 negative
	{
		bi2Neg = true;
		bi2 = -bi2;
	}

	BigInteger result;
	bool overflow = false;

	// Multiply the absolute values: result word k accumulates
	// bi1.data[i] * bi2.data[j] for all i + j == k.
	for (int i = 0; i < bi1.dataLength && !overflow; i++)
	{
		if (bi1.data[i] == 0) continue;

		unsigned __int64 mcarry = 0;   // carry into the next result word
		for (int j = 0, k = i; j < bi2.dataLength; j++, k++)
		{
			if (k >= maxLength)
			{
				overflow = true;   // the product does not fit in the buffer
				break;
			}

			// (2^32-1)^2 + 2*(2^32-1) == 2^64-1, so this sum cannot overflow.
			unsigned __int64 val = ((unsigned __int64)bi1.data[i] * (unsigned __int64)bi2.data[j]) +
				(unsigned __int64)result.data[k] + mcarry;

			result.data[k] = (unsigned int)(val & 0xFFFFFFFF);   // low 32 bits stay here
			mcarry = (val >> 32);                                // high 32 bits carry on
		}

		if (!overflow && mcarry != 0)
		{
			if (i + bi2.dataLength < maxLength)
				result.data[i + bi2.dataLength] = (unsigned int)mcarry;
			else
				overflow = true;
		}
	}

	if (overflow)
		assert(false);

	result.dataLength = bi1.dataLength + bi2.dataLength;
	if (result.dataLength > maxLength)
		result.dataLength = maxLength;

	while (result.dataLength > 1 && result.data[result.dataLength - 1] == 0)
		result.dataLength--;

	// Overflow check: the product of the absolute values set the sign bit.
	if ((result.data[lastPos] & 0x80000000) != 0)
	{
		// The only representable case is the most negative two's-complement
		// value (0x80000000 followed by zero words) with operands of different
		// sign; it is already its own negation, so return it as is.
		if (bi1Neg != bi2Neg && result.data[lastPos] == 0x80000000)
		{
			if (result.dataLength == 1)
				return result;
			else
			{
				bool isMaxNeg = true;
				for (int i = 0; i < result.dataLength - 1 && isMaxNeg; i++)
				{
					if (result.data[i] != 0)
						isMaxNeg = false;
				}

				if (isMaxNeg)
					return result;
			}
		}

		assert(false);
	}

	// If the inputs have different signs, the result is negative.
	if (bi1Neg != bi2Neg)
		return -result;

	return result;
}
/// Copy assignment: copies bi2's used words into this object's buffer and
/// zeroes the rest. Returns a copy of *this (the declared return type is by
/// value).
///
/// Every constructor allocates exactly maxLength words, so the existing
/// buffer is reused rather than freed and reallocated. This is cheaper, and a
/// failed allocation can no longer leave the object with a null buffer.
BigInteger BigInteger::operator =(const BigInteger &bi2)
{
	if (&bi2 == this)
		return *this;

	if (data == NULL)   // defensive: all constructors allocate, but stay safe
		data = new unsigned int[maxLength];

	memset(data, 0, maxLength * sizeof(unsigned int));

	dataLength = bi2.dataLength;
	for (int i = 0; i < dataLength; i++)
		data[i] = bi2.data[i];

	return *this;
}

/// Two's-complement addition: returns (*this) + bi2.
///
/// Words are added from least to most significant with a running carry. A
/// final carry extends the result by one word if there is room. If the
/// buffer is already full, the carry is dropped, which is the wrap-around
/// two's complement needs when adding negative values. Two operands of the
/// same sign that produce a sum of the other sign are an overflow and
/// trigger assert(false).
BigInteger BigInteger::operator +(const BigInteger &bi2)
{
	BigInteger sum;

	int width = this->dataLength;
	if (bi2.dataLength > width)
		width = bi2.dataLength;
	sum.dataLength = width;

	__int64 carry = 0;
	for (int w = 0; w < width; ++w)
	{
		const __int64 total = (__int64)this->data[w] + (__int64)bi2.data[w] + carry;
		sum.data[w] = (unsigned int)(total & 0xFFFFFFFF);
		carry = total >> 32;
	}

	const bool roomForCarry = sum.dataLength < maxLength;
	if (carry != 0 && roomForCarry)
		sum.data[sum.dataLength++] = (unsigned int)carry;

	while (sum.dataLength > 1 && sum.data[sum.dataLength - 1] == 0)
		--sum.dataLength;

	// Overflow: the operands share a sign but the sum does not.
	const int top = maxLength - 1;
	const unsigned int signA = this->data[top] & 0x80000000;
	const unsigned int signB = bi2.data[top] & 0x80000000;
	const unsigned int signSum = sum.data[top] & 0x80000000;
	if (signA == signB && signSum != signA)
		assert(false);

	return sum;
}

/// Two's-complement negation: returns -(*this).
///
/// Zero is returned unchanged. Otherwise all maxLength words are inverted
/// and one is added. Negating the most negative value overflows (the sign
/// does not change); that case triggers assert(false), as the file's other
/// overflow checks do.
BigInteger BigInteger::operator -()
{
	// Handle negation of zero separately, since it would look like an
	// overflow below (the sign bit stays clear).
	if (this->dataLength == 1 && this->data[0] == 0)
		return *this;

	BigInteger result(*this);

	// 1's complement of the whole buffer (this also sign-extends).
	for (int i = 0; i < maxLength; i++)
		result.data[i] = (unsigned int)(~(this->data[i]));

	// Add one to the 1's complement.
	__int64 val, carry = 1;
	int index = 0;

	while (carry != 0 && index < maxLength)
	{
		val = (__int64)(result.data[index]);
		val++;

		result.data[index] = (unsigned int)(val & 0xFFFFFFFF);
		carry = val >> 32;

		index++;
	}

	// Overflow in negation: only the most negative value keeps its sign.
	if ((this->data[maxLength - 1] & 0x80000000) == (result.data[maxLength - 1] & 0x80000000))
		assert(false);

	// The result may now use every word (a negated positive value has
	// 0xFFFFFFFF words up to the top), so its length must be recomputed from
	// the top of the buffer instead of inheriting this->dataLength.
	result.dataLength = maxLength;

	while (result.dataLength > 1 && result.data[result.dataLength - 1] == 0)
		result.dataLength--;
	return result;
}

/// Modular exponentiation: returns (*this)^exp mod n.
///
/// Uses right-to-left binary (square-and-multiply) exponentiation with a
/// Barrett reduction after every product. A negative exponent is rejected by
/// returning 0. A negative base is reduced via its absolute value, and the
/// sign is restored when the exponent is odd. A negative modulus is replaced
/// by its absolute value. A modulus too large for the Barrett constant
/// triggers assert(false) and returns 0.
BigInteger BigInteger::modPow(BigInteger exp, BigInteger n)
{
	if ((exp.data[maxLength - 1] & 0x80000000) != 0)   // negative exponent
	{
		// Positive exponents only; 0 is returned as the error value.
		return BigInteger((__int64)0);
	}

	BigInteger resultNum((__int64)1);
	BigInteger tempNum;
	bool thisNegative = false;

	if ((this->data[maxLength - 1] & 0x80000000) != 0)   // negative this
	{
		tempNum = -(*this) % n;
		thisNegative = true;
	}
	else
		tempNum = (*this) % n;  // ensures (tempNum * tempNum) < b^(2k)

	if ((n.data[maxLength - 1] & 0x80000000) != 0)   // negative n
		n = -n;

	// Calculate constant = b^(2k) / n, with b = 2^32 (one word) and
	// k = n.dataLength. b^(2k) is a single 1 at word index 2k, which must lie
	// inside the fixed-size buffer.
	int i = n.dataLength << 1;
	if (i >= maxLength)
	{
		assert(false);   // modulus too large for this representation
		return BigInteger((__int64)0);
	}

	BigInteger constant;
	constant.data[i] = 0x00000001;
	constant.dataLength = i + 1;

	constant = constant / n;   // Barrett constant used by every reduction below
	int totalBits = exp.bitCount();
	int count = 0;

	// Square-and-multiply: for each exponent bit (lowest first), multiply the
	// result by the current power when the bit is set, then square the power.
	for (int pos = 0; pos < exp.dataLength; pos++)
	{
		unsigned int mask = 0x01;

		for (int index = 0; index < 32; index++)
		{
			if ((exp.data[pos] & mask) != 0)   // this exponent bit is set
				resultNum = BarrettReduction(resultNum * tempNum, n, constant);

			mask <<= 1;

			tempNum = BarrettReduction(tempNum * tempNum, n, constant);

			// Once the running power is 1 it stays 1, so the result is final.
			if (tempNum.dataLength == 1 && tempNum.data[0] == 1)
			{
				if (thisNegative && (exp.data[0] & 0x1) != 0)    // odd exp
					return -resultNum;
				return resultNum;
			}
			count++;
			if (count == totalBits)   // no higher exponent bits remain
				break;
		}
	}

	if (thisNegative && (exp.data[0] & 0x1) != 0)    // odd exp
		return -resultNum;

	return resultNum;
}

/// Returns the number of significant bits: the 1-based position of the
/// highest set bit, or 0 for the value 0. Side effect: trims high-order zero
/// words from dataLength.
int BigInteger::bitCount()
{
	while (dataLength > 1 && data[dataLength - 1] == 0)
		--dataLength;

	// Count the significant bits of the top word by shifting it down to zero.
	int topBits = 0;
	for (unsigned int rest = data[dataLength - 1]; rest != 0; rest >>= 1)
		++topBits;

	// Every lower word contributes a full 32 bits.
	return topBits + (dataLength - 1) * 32;
}

/// Barrett reduction: returns x mod n using multiplications and word shifts
/// instead of a long division.
///
/// With b = 2^32 (one word) and k = n.dataLength, `constant` must be
/// floor(b^(2k) / n), which modPow precomputes. Derivation:
///   x mod n = x - floor(x / n) * n
///   floor(x / n) is estimated by
///     q = floor( floor(x / b^(k-1)) * floor(b^(2k) / n) / b^(k+1) )
///   so x mod n = x - q * n, computed modulo b^(k+1) and then corrected by a
///   few subtractions of n.
/// NOTE(review): x is presumably non-negative and below b^(2k) (a product of
/// two values already reduced mod n, as modPow passes); this is not checked.
BigInteger BigInteger::BarrettReduction(BigInteger x, BigInteger n, BigInteger constant)
{
	int k = n.dataLength,
		kPlusOne = k + 1,
		kMinusOne = k - 1;

	BigInteger q1;

	// q1 = x / b^(k-1): drop the lowest k-1 words
	for (int i = kMinusOne, j = 0; i < x.dataLength; i++, j++)
		q1.data[j] = x.data[i];
	q1.dataLength = x.dataLength - kMinusOne;
	if (q1.dataLength <= 0)
		q1.dataLength = 1;

	BigInteger q2 = q1 * constant;   // q2 = floor(x / b^(k-1)) * floor(b^(2k) / n)
	BigInteger q3;

	// q3 = q2 / b^(k+1): drop the lowest k+1 words; q3 estimates floor(x / n)
	for (int i = kPlusOne, j = 0; i < q2.dataLength; i++, j++)
		q3.data[j] = q2.data[i];
	q3.dataLength = q2.dataLength - kPlusOne;
	if (q3.dataLength <= 0)
		q3.dataLength = 1;

	// r1 = x mod b^(k+1)
	// i.e. keep the lowest (k+1) words
	BigInteger r1;
	int lengthToCopy = (x.dataLength > kPlusOne) ? kPlusOne : x.dataLength;
	for (int i = 0; i < lengthToCopy; i++)
		r1.data[i] = x.data[i];
	r1.dataLength = lengthToCopy;

	// r2 = (q3 * n) mod b^(k+1)
	// partial multiplication of q3 and n: product words at index >= k+1 are
	// never computed, since they vanish modulo b^(k+1)

	BigInteger r2;
	for (int i = 0; i < q3.dataLength; i++)
	{
		if (q3.data[i] == 0) continue;

		unsigned __int64 mcarry = 0;
		int t = i;
		for (int j = 0; j < n.dataLength && t < kPlusOne; j++, t++)
		{
			// t = i + j
			unsigned __int64 val = ((unsigned __int64)q3.data[i] * (unsigned __int64)n.data[j]) +
				(unsigned __int64)r2.data[t] + mcarry;

			r2.data[t] = (unsigned int)(val & 0xFFFFFFFF);
			mcarry = (val >> 32);
		}

		if (t < kPlusOne)
			r2.data[t] = (unsigned int)mcarry;
	}
	r2.dataLength = kPlusOne;
	while (r2.dataLength > 1 && r2.data[r2.dataLength - 1] == 0)
		r2.dataLength--;

	// r1 - r2 is congruent to x - q3 * n modulo b^(k+1)
	r1 -= r2;
	if ((r1.data[maxLength - 1] & 0x80000000) != 0)        // negative
	{
		// add b^(k+1) back to bring the difference into [0, b^(k+1))
		BigInteger val;
		val.data[kPlusOne] = 0x00000001;
		val.dataLength = kPlusOne + 1;
		r1 += val;
	}

	// q3 may underestimate floor(x / n); finish with a few subtractions
	while (r1 >= n)
		r1 -= n;

	return r1;
}

/// Signed greater-than. Operands of different sign are ordered by sign.
/// Operands of the same sign are compared word by word, from the most
/// significant used word down; this is also valid for two's complement.
bool BigInteger::operator >(BigInteger bi2)
{
	// Compare a copy: the copy constructor keeps only the first dataLength
	// words, so any word past dataLength reads as zero.
	BigInteger lhs(*this);
	const int top = maxLength - 1;
	const bool lhsNegative = (lhs.data[top] & 0x80000000) != 0;
	const bool rhsNegative = (bi2.data[top] & 0x80000000) != 0;

	if (lhsNegative != rhsNegative)
		return rhsNegative;   // the non-negative side is the larger one

	const int width = (lhs.dataLength > bi2.dataLength) ? lhs.dataLength : bi2.dataLength;
	for (int w = width - 1; w >= 0; --w)
	{
		if (lhs.data[w] != bi2.data[w])
			return lhs.data[w] > bi2.data[w];
	}
	return false;   // equal
}

/// Equality: true when both values use the same number of words and every
/// used word matches. Relies on dataLength being normalized.
bool BigInteger::operator ==(BigInteger bi2)
{
	if (dataLength != bi2.dataLength)
		return false;

	int w = 0;
	while (w < dataLength && data[w] == bi2.data[w])
		++w;
	return w == dataLength;
}

/// Signed remainder of (*this) / bi2. The result takes the sign of the
/// dividend (*this); the sign of the divisor is ignored.
BigInteger BigInteger::operator %(BigInteger bi2)
{
	BigInteger dividend(*this);
	BigInteger quotient;
	BigInteger remainder(dividend);   // when |dividend| < |divisor| the answer is the dividend itself

	const int top = maxLength - 1;
	const bool dividendNeg = (dividend.data[top] & 0x80000000) != 0;

	if (dividendNeg)
		dividend = -dividend;
	if ((bi2.data[top] & 0x80000000) != 0)
		bi2 = -bi2;

	if (dividend < bi2)
		return remainder;

	// Short division for a one-word divisor, long division otherwise.
	if (bi2.dataLength == 1)
		singleByteDivide(dividend, bi2, quotient, remainder);
	else
		multiByteDivide(dividend, bi2, quotient, remainder);

	if (!dividendNeg)
		return remainder;
	return -remainder;
}

/// Short division of bi1 by a one-word divisor (bi2.data[0]).
///
/// Processes the dividend from the most significant word down, carrying the
/// partial remainder into each 64-bit division step. On return outQuotient
/// holds the quotient and outRemainder the remainder (< divisor). Operands
/// are expected to be non-negative; the callers in this file pass absolute
/// values. A zero divisor is not checked and results in a division by zero.
void BigInteger::singleByteDivide(BigInteger &bi1, BigInteger &bi2,
					  BigInteger &outQuotient, BigInteger &outRemainder)
{	// outQuotient receives the quotient, outRemainder the remainder
	unsigned int result[maxLength];	// quotient words, most significant first
	memset(result, 0, sizeof(unsigned int) * maxLength);
	int resultPos = 0;

	// copy dividend to remainder
	for (int i = 0; i < maxLength; i++)
		outRemainder.data[i] = bi1.data[i];
	outRemainder.dataLength = bi1.dataLength;

	while (outRemainder.dataLength > 1 && outRemainder.data[outRemainder.dataLength - 1] == 0)
		outRemainder.dataLength--;

	unsigned __int64 divisor = (unsigned __int64)bi2.data[0]; 
	int pos = outRemainder.dataLength - 1;
	unsigned __int64 dividend = (unsigned __int64)outRemainder.data[pos];

	// The top word is divided on its own; it yields a quotient word only when
	// it is at least as large as the divisor.
	if (dividend >= divisor)
	{
		unsigned __int64 quotient = dividend / divisor;
		result[resultPos++] = (unsigned __int64)quotient;

		outRemainder.data[pos] = (unsigned __int64)(dividend % divisor);
	}
	pos--;

	// Each remaining step divides (previous remainder word : current word).
	// The previous remainder word is below the divisor, so every quotient
	// fits in 32 bits.
	while (pos >= 0)
	{
		dividend = ((unsigned __int64)outRemainder.data[pos + 1] << 32) + (unsigned __int64)outRemainder.data[pos];
		unsigned __int64 quotient = dividend / divisor;
		result[resultPos++] = (unsigned int)quotient;

		outRemainder.data[pos + 1] = 0;
		outRemainder.data[pos--] = (unsigned int)(dividend % divisor);
	}

	// result[] is most significant first; store it reversed (least significant
	// word first) and clear the rest of the quotient buffer.
	outQuotient.dataLength = resultPos;
	int j = 0;
	for (int i = outQuotient.dataLength - 1; i >= 0; i--, j++)
		outQuotient.data[j] = result[i];
	for (; j < maxLength; j++)
		outQuotient.data[j] = 0;

	while (outQuotient.dataLength > 1 && outQuotient.data[outQuotient.dataLength - 1] == 0)
		outQuotient.dataLength--;

	// a quotient of zero still occupies one word
	if (outQuotient.dataLength == 0)
		outQuotient.dataLength = 1;

	while (outRemainder.dataLength > 1 && outRemainder.data[outRemainder.dataLength - 1] == 0)
		outRemainder.dataLength--;
}

/// Long division of bi1 by a multi-word divisor bi2 (bi2.dataLength >= 2;
/// the code reads bi2.data[bi2.dataLength - 2]).
///
/// This is schoolbook long division in base b = 2^32, in the form given by
/// Knuth (TAOCP Vol. 2, section 4.3.1, Algorithm D):
///   1. Normalize: shift dividend and divisor left until the divisor's top
///      word has its high bit set. This keeps the one-word quotient estimate
///      below from being more than slightly too large.
///   2. For each quotient word, from most significant down:
///      - estimate q_hat from the top two remainder words divided by the top
///        divisor word;
///      - refine q_hat with the second divisor word;
///      - compute (remainder window - q_hat * divisor) exactly, decrementing
///        q_hat while the product is too large.
///   3. Un-normalize: shift the final remainder right by the same amount.
/// Operands are expected to be non-negative; the callers in this file pass
/// absolute values. Side effect: bi2 is overwritten with its shifted value.
/// NOTE(review): `remainder` and `dividendPart` are raw new[] buffers and
/// leak if anything in the loop throws (e.g. bad_alloc).
void BigInteger::multiByteDivide(BigInteger &bi1, BigInteger &bi2,
					 BigInteger &outQuotient, BigInteger &outRemainder)
{
	// Working storage: quotient words (most significant first), and a copy of
	// the dividend with one extra word to absorb the normalization shift.
	unsigned int result[maxLength];
	memset(result, 0, sizeof(unsigned int) * maxLength);
	int remainderLen = bi1.dataLength + 1;
	unsigned int *remainder = new unsigned int[remainderLen];
	memset(remainder, 0, sizeof(unsigned int) * remainderLen);

	// Step 1 (normalize): count the leading zero bits of the divisor's top word.
	unsigned int mask = 0x80000000;
	unsigned int val = bi2.data[bi2.dataLength - 1];
	int shift = 0, resultPos = 0;

	while (mask != 0 && (val & mask) == 0)
	{
		shift++; mask >>= 1;
	}

	// Shift dividend and divisor left by that amount so the divisor's top bit is set.
	for (int i = 0; i < bi1.dataLength; i++)
		remainder[i] = bi1.data[i];
	this->shiftLeft(remainder, remainderLen, shift);
	bi2 = bi2 << shift;

	// j = number of quotient words still to produce; pos = index of the top
	// word of the current (divisor length + 1)-word window of the remainder.
	int j = remainderLen - bi2.dataLength;
	int pos = remainderLen - 1;

	// Top two words of the normalized divisor, used to estimate each quotient word.
	unsigned __int64 firstDivisorByte = bi2.data[bi2.dataLength - 1];
	unsigned __int64 secondDivisorByte = bi2.data[bi2.dataLength - 2];

	int divisorLen = bi2.dataLength + 1;
	unsigned int *dividendPart = new unsigned int[divisorLen];
	memset(dividendPart, 0, sizeof(unsigned int) * divisorLen);

	// Step 2: produce one quotient word per iteration, most significant first.
	while (j > 0)
	{
		// Estimate: q_hat = (top two remainder words) / (top divisor word).
		unsigned __int64 dividend = ((unsigned __int64)remainder[pos] << 32) + (unsigned __int64)remainder[pos - 1];

		unsigned __int64 q_hat = dividend / firstDivisorByte;
		unsigned __int64 r_hat = dividend % firstDivisorByte;

		// Refine: q_hat is too large if it does not fit in one word (== b),
		// or if q_hat * secondDivisorByte exceeds (r_hat : next remainder word).
		// The test is repeated only while r_hat still fits in one word.
		bool done = false;
		while (!done)
		{
			done = true;

			if (q_hat == 0x100000000 ||
				(q_hat * secondDivisorByte) > ((r_hat << 32) + remainder[pos - 2]))
			{
				q_hat--;
				r_hat += firstDivisorByte;

				if (r_hat < 0x100000000)
					done = false;
			}
		}

		// Copy the current window (most significant word first) and compute
		// window - q_hat * divisor exactly; the loop below fixes any remaining
		// overestimate of q_hat.
		for (int h = 0; h < divisorLen; h++)
			dividendPart[h] = remainder[pos - h];

		BigInteger kk(dividendPart, divisorLen);
		BigInteger ss = bi2 * BigInteger((__int64)q_hat);

		while (ss > kk)
		{
			q_hat--;
			ss -= bi2;
		}
		BigInteger yy = kk - ss;

		// Write the reduced window back (yy.data is least significant first).
		for (int h = 0; h < divisorLen; h++)
			remainder[pos - h] = yy.data[bi2.dataLength - h];

		result[resultPos++] = (unsigned int)q_hat;

		// Slide the window down by one word.
		pos--;
		j--;
	}

	// Store the quotient least significant word first and clear the rest.
	outQuotient.dataLength = resultPos;
	int y = 0;
	for (int x = outQuotient.dataLength - 1; x >= 0; x--, y++)
		outQuotient.data[y] = result[x];
	for (; y < maxLength; y++)
		outQuotient.data[y] = 0;

	while (outQuotient.dataLength > 1 && outQuotient.data[outQuotient.dataLength - 1] == 0)
		outQuotient.dataLength--;

	if (outQuotient.dataLength == 0)
		outQuotient.dataLength = 1;

	// Step 3 (un-normalize): shift the remainder back right by `shift` bits.
	outRemainder.dataLength = this->shiftRight(remainder, remainderLen, shift);

	for (y = 0; y < outRemainder.dataLength; y++)
		outRemainder.data[y] = remainder[y];
	for (; y < maxLength; y++)
		outRemainder.data[y] = 0;

	delete []remainder;
	delete []dividendPart;
}

int BigInteger::shiftRight(unsigned int buffer[], int bufferLen,int shiftVal)//右移shiftVal位
{
	int shiftAmount = 32;
	int invShift = 0;
	int bufLen = bufferLen;

	while (bufLen > 1 && buffer[bufLen - 1] == 0)
		bufLen--;//找出bufLen的实际长度

	for (int count = shiftVal; count > 0; )
	{
		if (count < shiftAmount)
		{
			shiftAmount = count;
			invShift = 32 - shiftAmount;
		}

		unsigned __int64 carry = 0;
		for (int i = bufLen - 1; i >= 0; i--)//从高位开始移动
		{
			unsigned __int64 val = ((unsigned __int64)buffer[i]) >> shiftAmount;
			val |= carry;  //按位求或

			carry = ((unsigned __int64)buffer[i]) << invShift;
			buffer[i] = (unsigned int)(val);
		}

		count -= shiftAmount;
	}

	while (bufLen > 1 && buffer[bufLen - 1] == 0)
		bufLen--;

	return bufLen;
}

/// Returns a copy of *this shifted left by shiftVal bits. Bits shifted past
/// the top of the maxLength-word buffer are lost.
BigInteger BigInteger::operator <<(int shiftVal)
{
	BigInteger shifted(*this);
	const int used = shiftLeft(shifted.data, maxLength, shiftVal);
	shifted.dataLength = used;
	return shifted;
}

int BigInteger::shiftLeft(unsigned int buffer[], int bufferLen, int shiftVal)//左移
{
	int shiftAmount = 32;
	int bufLen = bufferLen;

	while (bufLen > 1 && buffer[bufLen - 1] == 0)
		bufLen--;

	for (int count = shiftVal; count > 0; )
	{
		if (count < shiftAmount)
			shiftAmount = count;

		unsigned __int64 carry = 0;
		for (int i = 0; i < bufLen; i++)
		{
			unsigned __int64 val = ((unsigned __int64)buffer[i]) << shiftAmount;
			val |= carry;

			buffer[i] = (unsigned int)(val & 0xFFFFFFFF);
			carry = val >> 32;
		}

		if (carry != 0)
		{
			if (bufLen + 1 <= bufferLen)
			{
				buffer[bufLen] = (unsigned int)carry;
				bufLen++;
			}
		}
		count -= shiftAmount;
	}
	return bufLen;
}

/// Signed less-than. Operands of different sign are ordered by sign.
/// Operands of the same sign are compared word by word, from the most
/// significant used word down; this is also valid for two's complement.
bool BigInteger::operator <(BigInteger bi2)
{
	// Compare a copy: the copy constructor keeps only the first dataLength
	// words, so any word past dataLength reads as zero.
	BigInteger lhs(*this);
	const int top = maxLength - 1;
	const bool lhsNegative = (lhs.data[top] & 0x80000000) != 0;
	const bool rhsNegative = (bi2.data[top] & 0x80000000) != 0;

	if (lhsNegative != rhsNegative)
		return lhsNegative;   // the negative side is the smaller one

	const int width = (lhs.dataLength > bi2.dataLength) ? lhs.dataLength : bi2.dataLength;
	for (int w = width - 1; w >= 0; --w)
	{
		if (lhs.data[w] != bi2.data[w])
			return lhs.data[w] < bi2.data[w];
	}
	return false;   // equal
}

/// Adds bi2 to *this in place and returns a copy of the new value.
BigInteger BigInteger::operator +=(BigInteger bi2)
{
	BigInteger sum = *this + bi2;
	*this = sum;
	return *this;
}

/// Signed truncating division: returns (*this) / bi2. The quotient is
/// negative when exactly one operand is negative, and is 0 when
/// |*this| < |bi2|.
BigInteger BigInteger::operator /(BigInteger bi2)
{
	BigInteger dividend(*this);
	BigInteger quotient;
	BigInteger remainder;

	const int top = maxLength - 1;
	const bool dividendNeg = (dividend.data[top] & 0x80000000) != 0;
	const bool divisorNeg = (bi2.data[top] & 0x80000000) != 0;

	// Divide the absolute values.
	if (dividendNeg)
		dividend = -dividend;
	if (divisorNeg)
		bi2 = -bi2;

	if (dividend < bi2)
		return quotient;

	// Short division for a one-word divisor, long division otherwise.
	if (bi2.dataLength == 1)
		singleByteDivide(dividend, bi2, quotient, remainder);
	else
		multiByteDivide(dividend, bi2, quotient, remainder);

	if (dividendNeg == divisorNeg)
		return quotient;
	return -quotient;
}

/// Subtracts bi2 from *this in place and returns a copy of the new value.
BigInteger BigInteger::operator -=(BigInteger bi2)
{
	BigInteger difference = *this - bi2;
	*this = difference;
	return *this;
}

/// Two's-complement subtraction: returns (*this) - bi2.
///
/// Words are subtracted with a running borrow. A borrow left after the last
/// used word means the result is negative, so the remaining high words are
/// filled with 0xFFFFFFFF (sign extension). Operands of opposite sign whose
/// difference has a sign different from *this are an overflow and trigger
/// assert(false).
BigInteger BigInteger::operator -(BigInteger bi2)
{
	BigInteger lhs(*this);
	BigInteger diff;

	int width = lhs.dataLength;
	if (bi2.dataLength > width)
		width = bi2.dataLength;
	diff.dataLength = width;

	__int64 borrow = 0;
	for (int w = 0; w < width; ++w)
	{
		const __int64 d = (__int64)lhs.data[w] - (__int64)bi2.data[w] - borrow;
		diff.data[w] = (unsigned int)(d & 0xFFFFFFFF);
		borrow = (d < 0) ? 1 : 0;
	}

	// Roll over to negative: sign-extend through the whole buffer.
	if (borrow != 0)
	{
		for (int w = width; w < maxLength; ++w)
			diff.data[w] = 0xFFFFFFFF;
		diff.dataLength = maxLength;
	}

	// Trim so that e.g. a - (-b) reports its correct length.
	while (diff.dataLength > 1 && diff.data[diff.dataLength - 1] == 0)
		--diff.dataLength;

	const int top = maxLength - 1;
	const unsigned int signL = lhs.data[top] & 0x80000000;
	const unsigned int signR = bi2.data[top] & 0x80000000;
	const unsigned int signD = diff.data[top] & 0x80000000;
	if (signL != signR && signD != signL)
		assert(false);

	return diff;
}

/// Formats a 32-bit word as upper-case hexadecimal.
///
/// With format "X8" the result is left-padded with '0' to at least 8 digits;
/// any other format gives the minimal number of digits. The value 0 becomes
/// "0" (or "00000000" with "X8") rather than an empty string, so a zero
/// BigInteger still has a non-empty hex representation.
string BigInteger::DecToHex(unsigned int value, string format)
{
	static const char hexDigits[] = "0123456789ABCDEF";

	// Emit digits least significant first (at least one), then reverse.
	string hexStr;
	do
	{
		hexStr.push_back(hexDigits[value % 16]);
		value /= 16;
	} while (value > 0);
	std::reverse(hexStr.begin(), hexStr.end());

	if (format == "X8" && hexStr.size() < 8)
		hexStr.insert(0, 8 - hexStr.size(), '0');

	return hexStr;
}

string BigInteger::ToHexString()
{
	string result = DecToHex(data[dataLength - 1], string("X"));

	for (int i = dataLength - 2; i >= 0; i--)
	{
		result += DecToHex(data[i], string("X8"));
	}

	return result;
}
加密的头文件:

#pragma once
#include "BigInteger.h"
#include <cmath>
#include <string>

using namespace std;

/// Textbook RSA over BigInteger with a fixed key: ParamN, ParamE and ParamD
/// are set in the constructor. Buffers are big-endian byte arrays.
///
/// NOTE(review): no padding scheme is applied; the bytes are converted
/// straight to an integer and exponentiated. This is not secure for real use,
/// and the input must be numerically smaller than ParamN (modPow reduces it
/// mod N first).
class MdRSACrypto
{
	typedef unsigned char byte;
public:
	MdRSACrypto(void);
	// data^E mod N; secData is new[]-allocated and must be delete[]d by the caller.
	void Encrypt(byte data[], int dataLen, byte *&secData, int &secDataLen);
	// data^D mod N; srcData is new[]-allocated and must be delete[]d by the caller.
	void Decrypt(byte data[], int dataLen, byte *&srcData, int &srcDataLen);
	// Serializes bi as big-endian bytes into a new[]-allocated buffer.
	void BigIntegerToBytes(BigInteger bi, byte *&secData, int &secDataLen);
	// Parses a NUL-terminated string of hex digits into one byte.
	byte htoi(const char *str);
public:
	~MdRSACrypto(void);
public:
	BigInteger ParamD;   // private exponent
	BigInteger ParamE;   // public exponent
	BigInteger ParamN;   // modulus
};


加密类的实现:

#include "StdAfx.h"
#include "MdRSACrypto.h"

/// Loads the hard-coded key: a roughly 512-bit modulus N with private
/// exponent D and public exponent E, all given in decimal. No primes are
/// generated or tested here. Initializers are listed in member declaration
/// order (D, E, N), which is the order in which they actually run.
MdRSACrypto::MdRSACrypto(void)
: ParamD(string("5644942460824217305807314808246519826955131459745825064823270702643487645914327490405459900282526319197371417079551956432087864023339947421172316448189099"), 10)
, ParamE(string("131363"), 10)
, ParamN(string("9925267379821871425510845582445759503537677132811205937337767819243983130290431496536905675097172626108809849063099490383615206095849973923022381920579193"), 10)
{
}

/// Nothing to release explicitly: the BigInteger members free their own buffers.
MdRSACrypto::~MdRSACrypto(void) {}

/// RSA encryption of a big-endian byte buffer: secData = data^E mod N.
/// secData receives a new[]-allocated buffer of secDataLen bytes that the
/// caller must release with delete[].
void MdRSACrypto::Encrypt(byte data[], int dataLen, byte *&secData, int &secDataLen)
{
	BigInteger plain(data, dataLen);
	BigIntegerToBytes(plain.modPow(ParamE, ParamN), secData, secDataLen);
}

/// RSA decryption of a big-endian byte buffer: srcData = data^D mod N.
/// srcData receives a new[]-allocated buffer of srcDataLen bytes that the
/// caller must release with delete[].
void MdRSACrypto::Decrypt(byte data[], int dataLen, byte *&srcData, int &srcDataLen)
{
	BigInteger cipher(data, dataLen);
	BigIntegerToBytes(cipher.modPow(ParamD, ParamN), srcData, srcDataLen);
}

/// Serializes a non-negative BigInteger as big-endian bytes without leading
/// zero bytes; the value 0 becomes a single 0x00 byte. secData receives a
/// new[]-allocated buffer of secDataLen bytes that the caller owns and must
/// release with delete[].
void MdRSACrypto::BigIntegerToBytes(BigInteger bi, byte *&secData, int &secDataLen)
{
	string str = bi.ToHexString();

	// Defensive: an empty hex string (a formatter that prints 0 as nothing)
	// would otherwise produce a zero-length buffer and then write secData[0].
	if (str.empty())
		str = "0";

	// Two hex digits per byte; with an odd digit count the first byte has one.
	secDataLen = (int)((str.size() + 1) / 2);
	secData = new byte[secDataLen];
	memset(secData, 0, sizeof(byte) * secDataLen);

	const int mod = (int)(str.size() % 2);
	secData[0] = htoi(str.substr(0, mod == 0 ? 2 : 1).c_str());

	for (int i = 1; i < secDataLen; i++)
		secData[i] = htoi(str.substr(i * 2 - mod, 2).c_str());
}

/// Parses a NUL-terminated string of hex digits (0-9, A-F, a-f) into a byte,
/// four bits per character; only the last two digits survive. Characters
/// outside those ranges are not rejected.
MdRSACrypto::byte MdRSACrypto::htoi(const char *str)
{
	byte dec = 0;
	for (const char *p = str; *p != '\0'; ++p)
	{
		byte t = (byte)*p;
		if (t < 58)                  // up to '9' (58 is ':')
			t -= 48;
		else if (t > 64 && t < 71)   // 'A'..'F'
			t -= 55;
		else if (t > 96 && t < 103)  // 'a'..'f'
			t -= 87;
		dec = (byte)((dec << 4) | t);
	}
	return dec;
}


测试文件:

// CppRsaTest.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include <iostream>
#include "BigInteger.h"
#include "MdRSACrypto.h"
#include <string>

using namespace std;

int main()
{
	string message = "I LOVE YOU SO!!!"; // 要加密信息

	MdRSACrypto rsa;

	unsigned char *secData = NULL;
	int secDataLen = 0;

	rsa.Encrypt((unsigned char *)message.c_str(), message.size() + 1, secData, secDataLen);

	BigInteger secInt(secData, secDataLen);
	cout << "密文数据:" << secInt.ToHexString() << endl;

	unsigned char *decData = NULL;
	int decDataLen = 0;
	
	rsa.Decrypt(secData, secDataLen, decData, decDataLen);

	BigInteger decInt(decData, decDataLen);
	cout << "解密数据:" << decInt.ToHexString() << endl;

	string decMessage = (char *)decData;
	cout << "解密明文:" << decMessage << endl;

	if (secData != NULL)
	{
		delete []secData;
	}

	if (decData != NULL)
	{
		delete []decData;
	}

	return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: