`
bianku
  • 浏览: 69893 次
  • 性别: Icon_minigender_1
  • 来自: 常州
社区版块
存档分类
最新评论

并行排序算法

阅读更多
今天早晨看到 蛙蛙池塘 的这篇博客 谁能把这个程序的性能提升一倍?---并行排序算法 。促使我写了一个并行排序算法,这个排序算法充分利用多核CPU进行并行计算,从而提高排序的效率。 

先简单说一下蛙蛙池塘 给的A,B,C 三种算法(见上面引用的那篇博客),A算法将耗时的平方和开平方计算放到比较函数中,导致Array.Sort 时,每次亮亮比较都要执行平方和开平方计算,其平均算法复杂度为 O(nlog2n) 。 而B 将平方和开平方计算提取出来,算法复杂度降低到 O(n) ,这也就是为什么B比A效率要高很多的缘故。C 和 B 相比,将平方函数替换成了 x*x ,由于少了远程函数调用和Pow函数本身的开销,效率有提高了不少。我在C的基础上编写了D算法,D算法采用并行计算技术,在我的双核笔记本电脑上数据量比较大的情况下,其排序效率较C要提高30%左右。 

下面重点介绍这个并行排序算法。算法思路其实很简单,就是将要排序的数组按照处理器数量等分成若干段,然后用和处理器数量等同的线程并行对各个小段进行排序,排序结束和,再在单一线程中对这若干个已经排序的小段进行归并排序,最后输出完整的排序结果。考虑到和.Net 2.0 兼容,我没有用微软提供的并行库,而是用多线程来实现。 

下面是测试结果: 



n A B C D 
32768 0.7345 0.04122 0.0216 0.0254 
65535 1.5464 0.08863 0.05139 0.05149 
131072 3.2706 0.1858 0.118 0.108 
262144 6.8423 0.4056 0.29586 0.21849 
524288 15.0342 0.9689 0.7318 0.4906 
1048576 31.6312 1.9978 1.4646 1.074 
2097152 66.9134 4.1763 3.0828 2.3095 



从测试结果上看,当要排序的数组长度较短时,并行排序的效率甚至还没有不进行并行排序高,这主要是多线程的开销造成的。当数组长度增大到25万以上时,并行排序的优势开始体现出来,随着数组长度的增长,排序时间最后基本稳定在但线程排序时间的 74% 左右,其中并行排序的消耗大概在50%左右,归并排序的消耗在 14%左右。由此也可以推断,如果在4CPU的机器上,其排序时间最多可以减少到单线程的 14 + 25 = 39%。8 CPU 为 14 + 12.5 = 26.5% 

目前这个算法在归并算法上可能还有提高的余地,如果哪位高手能够进一步提高这个算法,不妨贴出来一起交流交流。 

下面分别给出并行排序和归并排序的代码: 

并行排序类 ParallelSort 

Paralletsort 类是一个通用的泛型,调用起来非常简单,下面给一个简单的int型数组的排序示例: 



class IntComparer : IComparer < int > 
{ 
IComparer Members #region IComparer<int> Members 

public int Compare( int x, int y) 
{ 
return x.CompareTo(y); 
} 

#endregion 
} 

public void SortInt( int [] array) 
{ 
Sort.ParallelSort < int > parallelSort = new Sort.ParallelSort < int > (); 
parallelSort.Sort(array, new IntComparer()); 
} 

只要实现一个T类型两两比较的接口,然后调用ParallelSort 的 Sort 方法就可以了,是不是很简单? 

下面是 ParallelSort类的代码 

using System; 
using System.Collections.Generic; 
using System.Linq; 
using System.Text; 
using System.Threading; 

namespace Sort 
{ 
/**/ /// <summary> 
/// ParallelSort 
/// </summary> 
/// <typeparam name="T"></typeparam> 
public class ParallelSort < T > 
{ 
enum Status 
{ 
Idle = 0 , 
Running = 1 , 
Finish = 2 , 
} 

class ParallelEntity 
{ 
public Status Status; 
public T[] Array; 
public IComparer < T > Comparer; 

public ParallelEntity(Status status, T[] array, IComparer < T > comparer) 
{ 
Status = status; 
Array = array; 
Comparer = comparer; 
} 
} 

private void ThreadProc(Object stateInfo) 
{ 
ParallelEntity pe = stateInfo as ParallelEntity; 

lock (pe) 
{ 
pe.Status = ParallelSort < T > .Status.Running; 

Array.Sort(pe.Array, pe.Comparer); 

pe.Status = ParallelSort < T > .Status.Finish; 
} 
} 

public void Sort(T[] array, IComparer < T > comparer) 
{ 
// Calculate process count 
int processorCount = Environment.ProcessorCount; 

// If array.Length too short, do not use Parallel sort 
if (processorCount == 1 || array.Length < processorCount) 
{ 
Array.Sort(array, comparer); 
return ; 
} 

// Split array 
ParallelEntity[] partArray = new ParallelEntity[processorCount]; 

int remain = array.Length; 
int partLen = array.Length / processorCount; 

// Copy data to splited array 
for ( int i = 0 ; i < processorCount; i ++ ) 
{ 
if (i == processorCount - 1 ) 
{ 
partArray[i] = new ParallelEntity(Status.Idle, new T[remain], comparer); 
} 
else 
{ 
partArray[i] = new ParallelEntity(Status.Idle, new T[partLen], comparer); 

remain -= partLen; 
} 

Array.Copy(array, i * partLen, partArray[i].Array, 0 , partArray[i].Array.Length); 
} 

// Parallel sort 
for ( int i = 0 ; i < processorCount - 1 ; i ++ ) 
{ 
ThreadPool.QueueUserWorkItem( new WaitCallback(ThreadProc), partArray[i]); 
} 

ThreadProc(partArray[processorCount - 1 ]); 
} 

private static void A(Vector[] vectors) 
{ 
Array.Sort(vectors, new VectorComparer()); 
} 

private static void B(Vector[] vectors) 
{ 
int n = vectors.Length; 
for ( int i = 0 ; i < n; i ++ ) 
{ 
Vector c1 = vectors[i]; 
c1.T = Math.Sqrt(Math.Pow(c1.X, 2 ) 
+ Math.Pow(c1.Y, 2 ) 
+ Math.Pow(c1.Z, 2 ) 
+ Math.Pow(c1.W, 2 )); 
} 
Array.Sort(vectors, new VectorComparer2()); 
} 

private static void C(Vector[] vectors) 
{ 
int n = vectors.Length; 
for ( int i = 0 ; i < n; i ++ ) 
{ 
Vector c1 = vectors[i]; 
c1.T = Math.Sqrt(c1.X * c1.X 
+ c1.Y * c1.Y 
+ c1.Z * c1.Z 
+ c1.W * c1.W); 
} 
Array.Sort(vectors, new VectorComparer2()); 
} 

private static void D(Vector[] vectors) 
{ 
int n = vectors.Length; 
for ( int i = 0 ; i < n; i ++ ) 
{ 
Vector c1 = vectors[i]; 
c1.T = Math.Sqrt(c1.X * c1.X 
+ c1.Y * c1.Y 
+ c1.Z * c1.Z 
+ c1.W * c1.W); 
} 

Sort.ParallelSort < Vector > parallelSort = new Sort.ParallelSort < Vector > (); 
parallelSort.Sort(vectors, new VectorComparer2()); 
} 

} 
}  
 
// Wait all threads finish 
for ( int i = 0 ; i < processorCount; i ++ ) 
{ 
while ( true ) 
{ 
lock (partArray[i]) 
{ 
if (partArray[i].Status == ParallelSort < T > .Status.Finish) 
{ 
break ; 
} 
} 

Thread.Sleep( 0 ); 
} 
} 

// Merge sort 
MergeSort < T > mergeSort = new MergeSort < T > (); 

List < T[] > source = new List < T[] > (processorCount); 

foreach (ParallelEntity pe in partArray) 
{ 
source.Add(pe.Array); 
} 

mergeSort.Sort(array, source, comparer); 
} 
} 
} 



多路归并排序类 MergeSort 

using System; 
using System.Collections.Generic; 
using System.Linq; 
using System.Text; 

namespace Sort 
{ 
/**/ /// <summary> 
/// MergeSort 
/// </summary> 
/// <typeparam name="T"></typeparam> 
public class MergeSort < T > 
{ 
public void Sort(T[] destArray, List < T[] > source, IComparer < T > comparer) 
{ 
// Merge Sort 
int [] mergePoint = new int [source.Count]; 

for ( int i = 0 ; i < source.Count; i ++ ) 
{ 
mergePoint[i] = 0 ; 
} 

int index = 0 ; 

while (index < destArray.Length) 
{ 
int min = - 1 ; 

for ( int i = 0 ; i < source.Count; i ++ ) 
{ 
if (mergePoint[i] >= source[i].Length) 
{ 
continue ; 
} 

if (min < 0 ) 
{ 
min = i; 
} 
else 
{ 
if (comparer.Compare(source[i][mergePoint[i]], source[min][mergePoint[min]]) < 0 ) 
{ 
min = i; 
} 
} 
} 

if (min < 0 ) 
{ 
continue ; 
} 

destArray[index ++ ] = source[min][mergePoint[min]]; 
mergePoint[min] ++ ; 
} 
} 

} 
} 



主函数及测试代码 在蛙蛙池塘 代码基础上修改 



using System; 
using System.Collections.Generic; 
using System.Diagnostics; 

namespace Vector4Test 
{ 
public class Vector 
{ 
public double W; 
public double X; 
public double Y; 
public double Z; 
public double T; 
} 

internal class VectorComparer : IComparer < Vector > 
{ 
public int Compare(Vector c1, Vector c2) 
{ 
if (c1 == null || c2 == null ) 
throw new ArgumentNullException( " Both objects must not be null " ); 
double x = Math.Sqrt(Math.Pow(c1.X, 2 ) 
+ Math.Pow(c1.Y, 2 ) 
+ Math.Pow(c1.Z, 2 ) 
+ Math.Pow(c1.W, 2 )); 
double y = Math.Sqrt(Math.Pow(c2.X, 2 ) 
+ Math.Pow(c2.Y, 2 ) 
+ Math.Pow(c2.Z, 2 ) 
+ Math.Pow(c2.W, 2 )); 
if (x > y) 
return 1 ; 
else if (x < y) 
return - 1 ; 
else 
return 0 ; 
} 
} 

internal class VectorComparer2 : IComparer < Vector > 
{ 
public int Compare(Vector c1, Vector c2) 
{ 
if (c1 == null || c2 == null ) 
throw new ArgumentNullException( " Both objects must not be null " ); 
if (c1.T > c2.T) 
return 1 ; 
else if (c1.T < c2.T) 
return - 1 ; 
else 
return 0 ; 
} 
} 

internal class Program 
{ 
private static void Print(Vector[] vectors) 
{ 
// foreach (Vector v in vectors) 
// { 
// Console.WriteLine(v.T); 
// } 
} 

private static void Main( string [] args) 
{ 
Vector[] vectors = GetVectors(); 

Console.WriteLine( string .Format( " n = {0} " , vectors.Length)); 

Stopwatch watch1 = new Stopwatch(); 
watch1.Start(); 
A(vectors); 
watch1.Stop(); 
Console.WriteLine( " A sort time: " + watch1.Elapsed); 
Print(vectors); 

vectors = GetVectors(); 
watch1.Reset(); 
watch1.Start(); 
B(vectors); 
watch1.Stop(); 
Console.WriteLine( " B sort time: " + watch1.Elapsed); 
Print(vectors); 

vectors = GetVectors(); 
watch1.Reset(); 
watch1.Start(); 
C(vectors); 
watch1.Stop(); 
Console.WriteLine( " C sort time: " + watch1.Elapsed); 
Print(vectors); 

vectors = GetVectors(); 
watch1.Reset(); 
watch1.Start(); 
D(vectors); 
watch1.Stop(); 
Console.WriteLine( " D sort time: " + watch1.Elapsed); 
Print(vectors); 

Console.ReadKey(); 
} 

private static Vector[] GetVectors() 
{ 
int n = 1 << 21 ; 
Vector[] vectors = new Vector[n]; 
Random random = new Random(); 

for ( int i = 0 ; i < n; i ++ ) 
{ 
vectors[i] = new Vector(); 
vectors[i].X = random.NextDouble(); 
vectors[i].Y = random.NextDouble(); 
vectors[i].Z = random.NextDouble(); 
vectors[i].W = random.NextDouble(); 
} 
return vectors; 

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics