00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #ifndef MULTITHREADEDALGORITHM_H
00011 #define MULTITHREADEDALGORITHM_H
00012
00013 #include <string>
00014 #include <vector>
00015 #include <math.h>
00016 #include "bmutex.h"
00017 #include "bthread.h"
00018 #include "bthread_signal.h"
00019 #include "DMutex.h"
00020 #include "EnumWrapper.h"
00021 #include "MessageLogResource.h"
00022
00023 #include <numeric>
00024 #include <algorithm>
00025 #include <sstream>
00026 #include "DesktopServices.h"
00027
00028 class Progress;
00029 class MessageLogMgr;
00030
00031
00032
00033
00034 namespace mta
00035 {
00036
00037
00038
00039
00040
00041
00042
00043
00044
/**
 * Declared here and defined in another translation unit.
 *
 * Presumably returns how many worker threads should be used for dataSize items of
 * work. TODO confirm against the implementation.
 *
 * @param dataSize  Amount of work to be split.
 */
unsigned int getNumRequiredThreads(unsigned int dataSize);

/**
 * Outcome codes used by the multi-threaded algorithm API.
 *
 * MultiThreadedAlgorithm sets FAILURE when a worker reports an error. ABORT is not
 * produced by any code visible in this header.
 */
enum ResultEnum
{
   SUCCESS = 0,
   FAILURE = 1,
   ABORT = 2
};
00051
00052
00053
00054
/**
 * ResultEnum wrapped in the project's EnumWrapper template (EnumWrapper.h).
 *
 * This is the status type returned throughout the multi-threaded algorithm API.
 */
typedef EnumWrapper<ResultEnum> Result;
00056
00057
00058
00059
/**
 * A unit of work that is executed on demand by calling run().
 *
 * ThreadReporter::runInMainThread() accepts commands of this type, and
 * AlgorithmThread derives from it.
 */
class ThreadCommand
{
public:
   /**
    * Destroys the command.
    *
    * Virtual so that a command deleted through a ThreadCommand pointer runs the
    * derived destructor (AlgorithmThread, for example, declares its own).
    * Previously this was missing, which made such a delete undefined behaviour.
    */
   virtual ~ThreadCommand() {}

   /**
    * Executes the command.
    */
   virtual void run() = 0;
};
00068
00069
00070
00071
00072 class ThreadReporter
00073 {
00074 public:
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084 virtual Result reportProgress(int threadIndex, int percentDone) = 0;
00085
00086
00087
00088
00089
00090
00091
00092
00093 virtual Result reportCompletion(int threadIndex) = 0;
00094
00095
00096
00097
00098
00099
00100
00101
00102 virtual Result reportError(std::string errorText) = 0;
00103
00104
00105
00106
00107
00108
00109 virtual std::string getErrorText() const = 0;
00110
00111
00112
00113
00114
00115
00116
00117
00118 virtual int getProgress(int threadIndex) const = 0;
00119
00120
00121
00122
00123
00124
00125
00126 virtual void runInMainThread(ThreadCommand& command) = 0;
00127 };
00128
00129
00130
00131
00132 class MultiThreadReporter : public ThreadReporter
00133 {
00134 public:
00135
00136
00137
00138 enum ReportTypeEnum
00139 {
00140 THREAD_NO_REPORT = 0x0,
00141 THREAD_PROGRESS = 0x1,
00142 THREAD_ERROR = 0x2,
00143 THREAD_COMPLETE = 0x4,
00144 THREAD_WORK = 0x8
00145 };
00146
00147
00148
00149
00150 typedef EnumWrapper<ReportTypeEnum> ReportType;
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168 MultiThreadReporter(int threadCount, Result* pResult, BMutex& mutexA, BThreadSignal& signalA, BMutex& mutexB,
00169 BThreadSignal& signalB);
00170
00171
00172
00173
00174 virtual ~MultiThreadReporter() {};
00175
00176
00177
00178
00179
00180
00181
00182 MultiThreadReporter(MultiThreadReporter& reporter);
00183
00184
00185
00186
00187 Result reportProgress(int threadIndex, int percentDone);
00188
00189
00190
00191
00192 Result reportCompletion(int threadIndex);
00193
00194
00195
00196
00197 Result reportError(std::string errorText);
00198
00199
00200
00201
00202
00203
00204 int getProgress() const;
00205
00206
00207
00208
00209 int getProgress(int threadIndex) const;
00210
00211
00212
00213
00214 std::string getErrorText() const;
00215
00216
00217
00218
00219 void runInMainThread(ThreadCommand& command);
00220
00221
00222
00223
00224
00225
00226
00227 void setReportType(ReportType type);
00228
00229
00230
00231
00232
00233
00234 unsigned int getReportType() const;
00235
00236
00237
00238
00239
00240
00241 ThreadCommand* getThreadCommand();
00242
00243 private:
00244 MultiThreadReporter& operator=(const MultiThreadReporter& rhs);
00245
00246 BMutex& mMutexA;
00247 BThreadSignal& mSignalA;
00248 BMutex& mMutexB;
00249 BThreadSignal& mSignalB;
00250 Result* mpResult;
00251 std::vector<int> mThreadProgress;
00252 std::string mErrorMessage;
00253 unsigned int mReportType;
00254 ThreadCommand* mpThreadCommand;
00255 mutable DMutex mReporterMutex;
00256 mutable DMutex mSignalMutex;
00257
00258 Result signalMainThread(ThreadCommand& reportStatus, ReportType type);
00259 };
00260
00261
00262
00263 #if defined(WIN_API)
00264 #pragma warning (push)
00265 #pragma warning (disable: 4355)
00266 #endif
00267
00268
00269
00270
00271 class AlgorithmThread : public ThreadCommand
00272 {
00273 public:
00274
00275
00276
00277
00278
00279
00280
00281
00282 AlgorithmThread(int threadIndex, ThreadReporter& reporter) :
00283 mpAlgorithmMutex(NULL),
00284 mReporter(reporter),
00285 mThreadHandle(static_cast<void*>(this), reinterpret_cast<void*>(AlgorithmThread::threadFunction)),
00286 mThreadIndex(threadIndex) {}
00287
00288
00289
00290
00291 virtual ~AlgorithmThread() {};
00292
00293
00294
00295
00296
00297
00298
00299 AlgorithmThread(const AlgorithmThread& thread) :
00300 mpAlgorithmMutex(thread.mpAlgorithmMutex),
00301 mReporter(thread.mReporter),
00302 mThreadHandle(static_cast<void*>(this), reinterpret_cast<void*>(AlgorithmThread::threadFunction)),
00303 mThreadIndex(thread.mThreadIndex) {}
00304
00305
00306
00307
00308
00309
00310
00311 static void threadFunction(AlgorithmThread* pThreadData);
00312
00313
00314
00315
00316 virtual void run() = 0;
00317
00318
00319
00320
00321
00322
00323 bool launch();
00324
00325
00326
00327
00328
00329
00330 bool wait();
00331
00332
00333
00334
00335
00336
00337
00338 void runInMainThread(ThreadCommand& command);
00339
00340
00341
00342
00343
00344
00345
00346
00347
00348 void setAlgorithmMutex(DMutex* pMutex);
00349
00350
00351
00352
00353
00354
00355
00356
00357 void waitForAlgorithmLoop();
00358
00359
00360
00361
00362 class Range
00363 {
00364 public:
00365
00366
00367
00368 Range() :
00369 mFirst(0),
00370 mLast(0)
00371 {
00372 }
00373
00374
00375
00376
00377
00378
00379 int computePercent(int index)
00380 {
00381 return (100 * (index - mFirst)) / (mLast - mFirst + 1);
00382 }
00383
00384
00385
00386
00387 int mFirst;
00388
00389
00390
00391
00392 int mLast;
00393 };
00394
00395 protected:
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405 Range getThreadRange(int threadCount, int dataSize) const;
00406
00407
00408
00409
00410
00411
00412 int getThreadIndex() const;
00413
00414
00415
00416
00417
00418
00419 ThreadReporter& getReporter() const;
00420
00421 private:
00422 DMutex* mpAlgorithmMutex;
00423 ThreadReporter& mReporter;
00424 BThread mThreadHandle;
00425 int mThreadIndex;
00426 };
00427
00428 #if defined(WIN_API)
00429 #pragma warning (pop)
00430 #endif
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
/**
 * Abstract sink for overall progress and error notifications.
 *
 * MultiThreadedAlgorithm forwards worker reports to an optional ProgressReporter.
 */
class ProgressReporter
{
public:
   /**
    * Destroys the reporter.
    *
    * Virtual so that derived reporters (which all declare their own destructors)
    * are destroyed correctly when deleted through a ProgressReporter pointer.
    * Without it, such a delete is undefined behaviour.
    */
   virtual ~ProgressReporter() {}

   /**
    * Reports overall progress.
    *
    * @param percent  Percentage complete.
    */
   virtual void reportProgress(int percent) = 0;

   /**
    * Reports an error.
    *
    * @param text  Description of the error.
    */
   virtual void reportError(const std::string& text) = 0;
};
00479
00480
00481
00482
00483 class ProgressObjectReporter : public ProgressReporter
00484 {
00485 public:
00486
00487
00488
00489
00490
00491
00492
00493
00494 ProgressObjectReporter(std::string baseMessage, Progress* pProgress) :
00495 mMessage(baseMessage),
00496 mpProgress(pProgress)
00497 {
00498 }
00499
00500
00501
00502
00503 virtual ~ProgressObjectReporter() {};
00504
00505
00506
00507
00508 void reportProgress(int percent);
00509
00510
00511
00512
00513 void reportError(const std::string& text);
00514
00515 private:
00516 std::string mMessage;
00517 Progress* mpProgress;
00518 };
00519
00520
00521
00522
00523 class StatusBarReporter : public ProgressReporter
00524 {
00525 public:
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536 StatusBarReporter(std::string baseMessage, const std::string& component, const std::string& key) :
00537 mMessage(baseMessage),
00538 mComponent(component),
00539 mKey(key)
00540 {}
00541
00542
00543
00544
00545 virtual ~StatusBarReporter() {};
00546
00547
00548
00549
00550 void reportProgress(int percent)
00551 {
00552 std::stringstream buf;
00553 buf << mMessage << ": " << percent << "%";
00554 Service<DesktopServices>()->setStatusBarMessage(buf.str());
00555 }
00556
00557
00558
00559
00560 void reportError(const std::string& text)
00561 {
00562 Service<DesktopServices>()->setStatusBarMessage(text);
00563 MessageResource msg("Error", mComponent, mKey);
00564 msg->addProperty("Message", text);
00565 }
00566 private:
00567 StatusBarReporter& operator=(const StatusBarReporter& rhs);
00568
00569 std::string mMessage;
00570 const std::string& mComponent;
00571 const std::string& mKey;
00572 };
00573
00574
00575
00576
00577
00578 class MultiPhaseProgressReporter : public ProgressReporter
00579 {
00580 public:
00581
00582
00583
00584
00585
00586
00587
00588
00589 MultiPhaseProgressReporter(ProgressReporter& base, const std::vector<int>& phaseWeights) :
00590 mReporter(base), mPhaseWeights(phaseWeights), mCurrentPhase(0) {}
00591
00592
00593
00594
00595 virtual ~MultiPhaseProgressReporter() {};
00596
00597
00598
00599
00600 void reportProgress(int percent);
00601
00602
00603
00604
00605 void reportError(const std::string& text);
00606
00607
00608
00609
00610
00611
00612
00613 void setCurrentPhase(int phase);
00614
00615
00616
00617
00618
00619
00620 int getCurrentPhase() const;
00621
00622 private:
00623 MultiPhaseProgressReporter& operator=(const MultiPhaseProgressReporter& rhs);
00624
00625 int convertPhaseProgressToTotalProgress(int phaseProgress);
00626
00627 ProgressReporter& mReporter;
00628 std::vector<int> mPhaseWeights;
00629 int mCurrentPhase;
00630 };
00631
00632
00633
00634
00635 template<class AlgInput, class AlgOutput, class AlgThread>
00636 class MultiThreadedAlgorithm
00637 {
00638 public:
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650
00651 MultiThreadedAlgorithm(int threadCount, const AlgInput& input, AlgOutput& output, ProgressReporter* pProgress);
00652
00653
00654
00655
00656 ~MultiThreadedAlgorithm();
00657
00658
00659
00660
00661
00662
00663 Result run();
00664
00665
00666
00667
00668
00669
00670
00671 std::string getErrorText() const
00672 {
00673 return mErrorText;
00674 }
00675
00676 private:
00677 MultiThreadedAlgorithm& operator=(const MultiThreadedAlgorithm& rhs);
00678
00679 Result createThreads(int threadCount);
00680 Result startAllThreads();
00681 Result waitForThreadsToComplete();
00682 int processCurrentReports(int percentDone);
00683 int processReport(unsigned int currentType, int percentDone);
00684 Result compileResults();
00685
00686 Result mCurrentStatus;
00687 const AlgInput& mInput;
00688 AlgOutput& mOutput;
00689 std::vector<AlgThread*> mThreads;
00690 MultiThreadReporter* mpThreadReporter;
00691 ProgressReporter* mpProgressReporter;
00692 DMutex mMutexA;
00693 DThreadSignal mSignalA;
00694 DMutex mMutexB;
00695 DThreadSignal mSignalB;
00696 std::string mErrorText;
00697 };
00698
00699 template<class AlgInput, class AlgOutput, class AlgThread>
00700 MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::MultiThreadedAlgorithm(int threadCount,
00701 const AlgInput& algInput, AlgOutput& algOutput, ProgressReporter* pReporter) :
00702 mCurrentStatus(SUCCESS),
00703 mInput(algInput),
00704 mOutput(algOutput),
00705 mpThreadReporter(NULL),
00706 mpProgressReporter(pReporter)
00707 {
00708 mpThreadReporter = new MultiThreadReporter(threadCount, &mCurrentStatus, mMutexA, mSignalA, mMutexB, mSignalB);
00709 createThreads(threadCount);
00710 }
00711
00712 template<class AlgInput, class AlgOutput, class AlgThread>
00713 MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::~MultiThreadedAlgorithm()
00714 {
00715 typename std::vector<AlgThread*>::iterator iter;
00716 for (iter = mThreads.begin(); iter != mThreads.end(); ++iter)
00717 {
00718 AlgThread* pThread = *iter;
00719 if (pThread != NULL)
00720 {
00721 delete pThread;
00722 }
00723 }
00724
00725 mThreads.clear();
00726 delete mpThreadReporter;
00727 }
00728
00729 template<class AlgInput, class AlgOutput, class AlgThread>
00730 Result MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::createThreads(int threadCount)
00731 {
00732 int i;
00733 for (i = 0; i < threadCount; ++i)
00734 {
00735 AlgThread* pThread = NULL;
00736 pThread = new AlgThread(mInput, threadCount, i, *mpThreadReporter);
00737 if (pThread != NULL)
00738 {
00739 pThread->setAlgorithmMutex(&mMutexA);
00740 mThreads.push_back(pThread);
00741 }
00742 }
00743 return SUCCESS;
00744 }
00745
00746 template<class AlgInput, class AlgOutput, class AlgThread>
00747 Result MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::startAllThreads()
00748 {
00749 typename std::vector<AlgThread*>::iterator iter;
00750
00751 mMutexA.MutexLock();
00752 mMutexB.MutexLock();
00753
00754 for (iter = mThreads.begin(); iter != mThreads.end(); ++iter)
00755 {
00756 (*iter)->launch();
00757 }
00758 return SUCCESS;
00759 }
00760
/**
 * Report-servicing loop that runs on the thread that called run().
 *
 * Each iteration waits on mSignalA, handles whatever reports are pending
 * (processCurrentReports), and stops once overall progress equals 100 or
 * mCurrentStatus is no longer SUCCESS. It then releases mMutexA and calls wait()
 * on every worker.
 *
 * Expects to be entered with mMutexA and mMutexB locked, as left by
 * startAllThreads().
 *
 * @return mCurrentStatus after the loop ends.
 */
template<class AlgInput, class AlgOutput, class AlgThread>
Result MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::waitForThreadsToComplete()
{
   bool doneProcessing = false;
   int percentDone = 0;

   // Release B (taken in startAllThreads) so the report hand-off can proceed.
   mMutexB.MutexUnlock();
   while (!doneProcessing)
   {
      // Assumed condition-variable semantics: releases mMutexA while waiting and
      // holds it again on return. TODO confirm in bthread_signal.h.
      // NOTE(review): if no worker ever signals (e.g. zero threads were created),
      // this wait does not return.
      mSignalA.ThreadSignalWait(&mMutexA);

      percentDone = processCurrentReports(percentDone);
      doneProcessing = (percentDone == 100 || mCurrentStatus != SUCCESS);
      if (doneProcessing)
      {
         mMutexA.MutexUnlock();

         // Presumably blocks until each worker exits; the return value is ignored.
         typename std::vector<AlgThread*>::iterator iter;
         for (iter = mThreads.begin(); iter != mThreads.end(); ++iter)
         {
            (*iter)->wait();
         }
      }
   }

   return mCurrentStatus;
}
00788
00789 template<class AlgInput, class AlgOutput, class AlgThread>
00790 int MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::processCurrentReports(int percentDone)
00791 {
00792 mMutexB.MutexLock();
00793
00794 int type = mpThreadReporter->getReportType();
00795 unsigned int currentType = MultiThreadReporter::THREAD_WORK;
00796 while (currentType != 0)
00797 {
00798 if (type & currentType)
00799 {
00800 percentDone = processReport(currentType, percentDone);
00801 }
00802 currentType /= 2;
00803 }
00804
00805 mpThreadReporter->setReportType(MultiThreadReporter::THREAD_NO_REPORT);
00806
00807 mSignalB.ThreadSignalActivate();
00808 mMutexB.MutexUnlock();
00809
00810 return percentDone;
00811 }
00812
00813 template<class AlgInput, class AlgOutput, class AlgThread>
00814 int MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::processReport(unsigned int currentType, int percentDone)
00815 {
00816 switch (currentType)
00817 {
00818 case MultiThreadReporter::THREAD_NO_REPORT:
00819 break;
00820 case MultiThreadReporter::THREAD_COMPLETE:
00821 case MultiThreadReporter::THREAD_PROGRESS:
00822 percentDone = mpThreadReporter->getProgress();
00823 if (mpProgressReporter != NULL)
00824 {
00825 mpProgressReporter->reportProgress(percentDone);
00826 }
00827 break;
00828 case MultiThreadReporter::THREAD_ERROR:
00829 if (mpProgressReporter != NULL)
00830 {
00831 mpProgressReporter->reportError(mpThreadReporter->getErrorText().c_str());
00832 }
00833 mErrorText = mpThreadReporter->getErrorText().c_str();
00834 mCurrentStatus = FAILURE;
00835 break;
00836 case MultiThreadReporter::THREAD_WORK:
00837 if (mpThreadReporter->getThreadCommand() != NULL)
00838 {
00839 mpThreadReporter->getThreadCommand()->run();
00840 }
00841 break;
00842 default:
00843 break;
00844 }
00845
00846 return percentDone;
00847 }
00848
00849 template<class AlgInput, class AlgOutput, class AlgThread>
00850 Result MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::compileResults()
00851 {
00852 bool success = mOutput.compileOverallResults(mThreads);
00853 return ((success == true) ? SUCCESS : FAILURE);
00854 }
00855
00856 template<class AlgInput, class AlgOutput, class AlgThread>
00857 Result MultiThreadedAlgorithm<AlgInput, AlgOutput, AlgThread>::run()
00858 {
00859 Result result = startAllThreads();
00860 if (result == SUCCESS)
00861 {
00862 result = waitForThreadsToComplete();
00863 }
00864
00865 if (result == SUCCESS)
00866 {
00867 result = compileResults();
00868 }
00869
00870 return result;
00871 }
00872
00873 }
00874
00875 #endif
00876