Skip to content

Commit 8fabb36

Browse files
authored
Merge pull request #48 from bmcnns/main
Adds a warm start mechanism to the Run method to allow for online/batch training
2 parents 2b00acd + b6de188 commit 8fabb36

File tree

5 files changed

+47
-22
lines changed

5 files changed

+47
-22
lines changed

include/operon/algorithms/ga_base.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,25 @@ class GeneticAlgorithmBase {
6060
[[nodiscard]] auto Elapsed() const -> double { return elapsed_; }
6161
auto Elapsed() -> double& { return elapsed_; }
6262

63+
[[nodiscard]] auto IsFitted() const -> bool { return isFitted_; }
64+
auto IsFitted() -> bool& { return isFitted_; }
65+
6366
auto Reset() -> void
6467
{
6568
generation_ = 0;
6669
elapsed_ = 0;
6770
GetGenerator()->Evaluator()->Reset();
6871
}
6972

73+
auto RestoreIndividuals(std::vector<Individual> inds) -> void
74+
{
75+
EXPECT(inds.size() == config_.PoolSize + config_.PopulationSize,
76+
"Mismatched number of individuals (must match pool/population sizes)");
77+
individuals_ = std::move(inds);
78+
parents_ = Operon::Span<Individual>(individuals_.data(), config_.PoolSize);
79+
offspring_ = Operon::Span<Individual>(individuals_.data() + config_.PoolSize, config_.PopulationSize);
80+
}
81+
7082
private:
7183
GeneticAlgorithmConfig config_;
7284

@@ -82,6 +94,7 @@ class GeneticAlgorithmBase {
8294

8395
size_t generation_{0};
8496
double elapsed_{0}; // elapsed time in microseconds
97+
bool isFitted_{false};
8598
};
8699

87100
} // namespace Operon

include/operon/algorithms/gp.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class OPERON_EXPORT GeneticProgrammingAlgorithm : public GeneticAlgorithmBase {
3434
{
3535
}
3636

37-
auto Run(tf::Executor& /*executor*/, Operon::RandomGenerator&/*rng*/, std::function<void()> /*report*/ = nullptr) -> void;
38-
auto Run(Operon::RandomGenerator& /*rng*/, std::function<void()> /*report*/ = nullptr, size_t /*threads*/= 0) -> void;
37+
auto Run(tf::Executor& /*executor*/, Operon::RandomGenerator&/*rng*/, std::function<void()> /*report*/ = nullptr, /*warmStart*/ bool = false) -> void;
38+
auto Run(Operon::RandomGenerator& /*rng*/, std::function<void()> /*report*/ = nullptr, size_t /*threads*/= 0, /*warmStart*/ bool = false) -> void;
3939
};
4040
} // namespace Operon
4141

include/operon/algorithms/nsga2.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ class OPERON_EXPORT NSGA2 : public GeneticAlgorithmBase {
4848

4949
[[nodiscard]] auto Best() const -> Operon::Span<Individual const> { return { best_.data(), best_.size() }; }
5050

51-
auto Run(tf::Executor& /*executor*/, Operon::RandomGenerator&/*rng*/, std::function<void()> /*report*/ = nullptr) -> void;
52-
auto Run(Operon::RandomGenerator& /*rng*/, std::function<void()> /*report*/ = nullptr, size_t /*threads*/= 0) -> void;
51+
auto Run(tf::Executor& /*executor*/, Operon::RandomGenerator&/*rng*/, std::function<void()> /*report*/ = nullptr, /*warmStart*/ bool = false) -> void;
52+
auto Run(Operon::RandomGenerator& /*rng*/, std::function<void()> /*report*/ = nullptr, size_t /*threads*/= 0, /*warmStart*/ bool = false) -> void;
5353
};
5454
} // namespace Operon
5555

source/algorithms/gp.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
#include "operon/operators/reinserter.hpp" // for ReinserterBase
2121

2222
namespace Operon {
23-
auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::function<void()> report) -> void
23+
auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::function<void()> report, bool warmStart) -> void
2424
{
25+
Reset();
26+
2527
const auto config = GetConfig();
2628
const auto& treeInit = GetTreeInitializer();
2729
const auto& coeffInit = GetCoefficientInitializer();
@@ -65,10 +67,6 @@ auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGene
6567
tf::Taskflow taskflow;
6668
auto [init, cond, body, back, done] = taskflow.emplace(
6769
[&](tf::Subflow& subflow) {
68-
auto init = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
69-
parents[i].Genotype = (*treeInit)(rngs[i]);
70-
(*coeffInit)(rngs[i], parents[i].Genotype);
71-
}).name("initialize population");
7270
auto prepareEval = subflow.emplace([&]() { evaluator->Prepare(parents); }).name("prepare evaluator");
7371
auto eval = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
7472
auto id = executor.this_worker_id();
@@ -79,9 +77,17 @@ auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGene
7977
parents[i].Fitness = (*evaluator)(rngs[i], parents[i], slots[id]);
8078
}).name("evaluate population");
8179
auto reportProgress = subflow.emplace([&](){ if (report) { std::invoke(report); } }).name("report progress");
82-
init.precede(prepareEval);
8380
prepareEval.precede(eval);
8481
eval.precede(reportProgress);
82+
83+
if (!(IsFitted() && warmStart)) {
84+
auto init = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
85+
parents[i].Genotype = (*treeInit)(rngs[i]);
86+
(*coeffInit)(rngs[i], parents[i].Genotype);
87+
}).name("initialize population");
88+
89+
init.precede(prepareEval);
90+
}
8591
}, // init
8692
stop, // loop condition
8793
[&](tf::Subflow& subflow) {
@@ -111,7 +117,7 @@ auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGene
111117
incrementGeneration.precede(reportProgress);
112118
}, // loop body (evolutionary main loop)
113119
[&]() { return 0; }, // jump back to the next iteration
114-
[&]() { /* all done */ } // work done, report last gen and stop
120+
[&]() { IsFitted() = true; /* all done */ } // work done, report last gen and stop
115121
); // evolutionary loop
116122

117123
init.name("init");
@@ -130,11 +136,11 @@ auto GeneticProgrammingAlgorithm::Run(tf::Executor& executor, Operon::RandomGene
130136
executor.wait_for_all();
131137
}
132138

133-
auto GeneticProgrammingAlgorithm::Run(Operon::RandomGenerator& random, std::function<void()> report, size_t threads) -> void {
139+
auto GeneticProgrammingAlgorithm::Run(Operon::RandomGenerator& random, std::function<void()> report, size_t threads, bool warmStart) -> void {
134140
if (threads == 0) {
135141
threads = std::thread::hardware_concurrency();
136142
}
137143
tf::Executor executor(threads);
138-
Run(executor, random, std::move(report));
144+
Run(executor, random, std::move(report), warmStart);
139145
}
140146
} // namespace Operon

source/algorithms/nsga2.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ auto NSGA2::Sort(Operon::Span<Individual> pop) -> void
9696
std::transform(fronts_.front().begin(), fronts_.front().end(), std::back_inserter(best_), [&](auto i) { return pop[i]; });
9797
}
9898

99-
auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::function<void()> report) -> void
99+
auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::function<void()> report, bool warmStart) -> void
100100
{
101+
Reset();
102+
101103
const auto& config = GetConfig();
102104
const auto& treeInit = GetTreeInitializer();
103105
const auto& coeffInit = GetCoefficientInitializer();
@@ -141,10 +143,6 @@ auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::fu
141143
tf::Taskflow taskflow;
142144
auto [init, cond, body, back, done] = taskflow.emplace(
143145
[&](tf::Subflow& subflow) {
144-
auto init = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
145-
parents[i].Genotype = (*treeInit)(rngs[i]);
146-
(*coeffInit)(rngs[i], parents[i].Genotype);
147-
}).name("initialize population");
148146
auto prepareEval = subflow.emplace([&]() { evaluator->Prepare(parents); }).name("prepare evaluator");
149147
auto eval = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
150148
// make sure the worker has a large enough buffer
@@ -156,10 +154,18 @@ auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::fu
156154
auto reportProgress = subflow.emplace([&]() {
157155
if (report) { std::invoke(report); }
158156
}).name("report progress");
159-
init.precede(prepareEval);
160157
prepareEval.precede(eval);
161158
eval.precede(nonDominatedSort);
162159
nonDominatedSort.precede(reportProgress);
160+
161+
if (!(IsFitted() && warmStart)) {
162+
auto init = subflow.for_each_index(size_t{0}, parents.size(), size_t{1}, [&](size_t i) {
163+
parents[i].Genotype = (*treeInit)(rngs[i]);
164+
(*coeffInit)(rngs[i], parents[i].Genotype);
165+
}).name("initialize population");
166+
167+
init.precede(prepareEval);
168+
}
163169
}, // init
164170
stop, // loop condition
165171
[&](tf::Subflow& subflow) {
@@ -189,7 +195,7 @@ auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::fu
189195
incrementGeneration.precede(reportProgress);
190196
}, // loop body (evolutionary main loop)
191197
[&]() { return 0; }, // jump back to the next iteration
192-
[&]() { /* done nothing to do */ } // work done, report last gen and stop
198+
[&]() { IsFitted() = true; /* all done */ } // work done, report last gen and stop
193199
); // evolutionary loop
194200

195201
init.name("init");
@@ -208,12 +214,12 @@ auto NSGA2::Run(tf::Executor& executor, Operon::RandomGenerator& random, std::fu
208214
executor.wait_for_all();
209215
}
210216

211-
auto NSGA2::Run(Operon::RandomGenerator& random, std::function<void()> report, size_t threads) -> void
217+
auto NSGA2::Run(Operon::RandomGenerator& random, std::function<void()> report, size_t threads, bool warmStart) -> void
212218
{
213219
if (threads == 0U) {
214220
threads = std::thread::hardware_concurrency();
215221
}
216222
tf::Executor executor(threads);
217-
Run(executor, random, std::move(report));
223+
Run(executor, random, std::move(report), warmStart);
218224
}
219225
} // namespace Operon

0 commit comments

Comments
 (0)