From d80e0dbd20bb8597afe92f770e39d20440557d1f Mon Sep 17 00:00:00 2001 From: Julian Blake Kongslie Date: Sat, 25 Jun 2022 10:16:12 -0700 Subject: Demo for a coroutine-based step evaluator. --- aisa/aisa.cpp | 5 --- aisa/aisa.h | 73 ++++++++++++++++++++++++++++++++- aisa/coroutine.h | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ main.cpp | 52 +++++++++++++++++++++++- 4 files changed, 243 insertions(+), 7 deletions(-) create mode 100644 aisa/coroutine.h diff --git a/aisa/aisa.cpp b/aisa/aisa.cpp index 0c806de..97a3b17 100644 --- a/aisa/aisa.cpp +++ b/aisa/aisa.cpp @@ -4,9 +4,4 @@ namespace aisa { - void do_something() - { - std::cout << "Hello, world!\n"; - } - } diff --git a/aisa/aisa.h b/aisa/aisa.h index b570d6b..8cb302e 100644 --- a/aisa/aisa.h +++ b/aisa/aisa.h @@ -1,7 +1,78 @@ #pragma once +#include +#include +#include +#include +#include +#include + +#include "aisa/coroutine.h" + namespace aisa { - void do_something(); + using regnum_t = std::uint_fast64_t; + using regval_t = std::uint64_t; + + template struct EvalState { + CRTP & crtp() noexcept { return static_cast(*this); } + task async_load_reg(regnum_t rn) + { + while (true) { + if (auto rv = crtp().load_reg(rn); rv.has_value()) + co_return *rv; + co_await suspend(); + } + } + }; + + struct Step { + const std::optional> predicate; + const std::vector source_regs; + const std::vector destination_regs; + + std::optional predicate_reg() const + { + if (predicate.has_value()) + return predicate->first; + return {}; + } + + std::optional expected_predicate_val() const + { + if (predicate.has_value()) + return predicate->second; + return {}; + } + + template task evaluate(State &state) const + { + if (predicate.has_value()) { + std::cout << "checking predicate...\n"; + std::cout << "\texpect " << predicate->second << "\n"; + regval_t pval = co_await state.async_load_reg(predicate->first); + std::cout << "\tgot " << pval << "\n"; + if (pval != predicate->second) { + std::cout << "\tpredicate skipped\n"; + co_return; + } else { + std::cout << "\tpredicate not skipped\n"; + } + } + std::cout << "reading sources...\n"; + std::vector source_vals; + source_vals.reserve(source_regs.size()); + for (unsigned int i = 0; i < source_regs.size(); ++i) { + std::cout << "\tgetting source " << i << "...\n"; + source_vals.emplace_back(co_await state.async_load_reg(source_regs[i])); + std::cout << "\t\tgot " << source_vals.back() << "\n"; + } + std::cout << "sources:"; + for (unsigned int i = 0; i < source_regs.size(); ++i) + std::cout << " " << source_regs[i] << "=" << source_vals[i]; + std::cout << "\n"; + std::cout << "done with evaluate\n"; + } + }; } diff --git a/aisa/coroutine.h b/aisa/coroutine.h new file mode 100644 index 0000000..40a6982 --- /dev/null +++ b/aisa/coroutine.h @@ -0,0 +1,120 @@ +#pragma once + +#include + +namespace aisa { + + inline auto suspend() noexcept { return std::suspend_always{}; } + + template struct promise; + + template struct task : public std::coroutine_handle> { + using handle = std::coroutine_handle>; + using promise_type = struct promise; + bool await_ready() const noexcept { return handle::done(); } + result_t await_resume() const noexcept; + template void await_suspend(std::coroutine_handle> h) const noexcept; + std::optional operator()() noexcept; + }; + + template<> struct task : public std::coroutine_handle> { + using handle = std::coroutine_handle>; + using promise_type = struct promise; + bool await_ready() const noexcept { return handle::done(); } + void await_resume() const noexcept; + template void await_suspend(std::coroutine_handle> h) const noexcept; + bool operator()() noexcept; + }; + + template struct promise { + std::coroutine_handle<> precursor; + std::optional result; + promise() = default; + promise(const promise &) = delete; + task get_return_object() noexcept { return task{std::coroutine_handle>::from_promise(*this)}; } + std::suspend_never initial_suspend() const noexcept { return {}; } + std::suspend_always final_suspend() const noexcept { return {}; } + void unhandled_exception() { } + void return_value(result_t x) noexcept { result = std::move(x); } + }; + + template<> struct promise { + std::coroutine_handle<> precursor; + promise() = default; + promise(const promise &) = delete; + task get_return_object() noexcept { return task{std::coroutine_handle>::from_promise(*this)}; } + std::suspend_never initial_suspend() const noexcept { return {}; } + std::suspend_always final_suspend() const noexcept { return {}; } + void unhandled_exception() { } + void return_void() noexcept { } + }; + + template result_t task::await_resume() const noexcept + { + auto x = std::move(handle::promise().result.value()); + handle::destroy(); + return std::move(x); + } + + template template void task::await_suspend(std::coroutine_handle> h) const noexcept + { + h.promise().precursor = *this; + } + + template std::optional task::operator()() noexcept + { + if (!handle::operator bool()) + return {}; + if (!handle::done()) { + auto &precursor = handle::promise().precursor; + if (precursor) { + if (!precursor.done()) + precursor.resume(); + if (precursor.done()) + precursor = nullptr; + } + if (!precursor) + handle::resume(); + } + if (handle::done()) { + auto x = await_resume(); + handle::operator=(nullptr); + return std::move(x); + } + return {}; + } + + inline void task::await_resume() const noexcept + { + handle::destroy(); + } + + template void task::await_suspend(std::coroutine_handle> h) const noexcept + { + h.promise().precursor = *this; + } + + inline bool task::operator()() noexcept + { + if (!handle::operator bool()) + return true; + if (!handle::done()) { + auto &precursor = handle::promise().precursor; + if (precursor) { + if (!precursor.done()) + precursor.resume(); + if (precursor.done()) + precursor = nullptr; + } + if (!precursor) + handle::resume(); + } + if (handle::done()) { + await_resume(); + handle::operator=(nullptr); + return true; + } + return false; + } + +} diff --git a/main.cpp b/main.cpp index 6b7b0f4..0240bb5 100644 --- a/main.cpp +++ b/main.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include "aisa/aisa.h" #include "git-tag.h" @@ -6,6 +8,54 @@ int main(int argc, const char *argv[]) { std::cout << "Version " << GIT_TAG << "\n"; - aisa::do_something(); + + aisa::Step step; + const_cast> &>(step.predicate) = std::optional(std::make_pair(123, 456)); + const_cast &>(step.source_regs).emplace_back(12); + const_cast &>(step.source_regs).emplace_back(34); + const_cast &>(step.source_regs).emplace_back(56); + + struct State : public aisa::EvalState { + std::map regs; + + std::optional load_reg(aisa::regnum_t rn) + { + std::cout << "state.load_reg(" << rn << ") = "; + if (auto x = regs.find(rn); x != regs.end()) { + std::cout << x->second << "\n"; + return x->second; + } + std::cout << "(not available)\n"; + return {}; + } + } state; + + auto t = state.async_load_reg(999); + t(); + t(); + t(); + std::cout << "set regs[999] = 54321\n"; state.regs[999] = 54321; + std::optional result; + while (!result.has_value()) + result = t(); + std::cout << "result = " << *result << "\n"; + + std::cout << "\n\n\n"; + + auto w = step.evaluate(state); + w(); + w(); + w(); + std::cout << "set predicate (valid)\n"; state.regs[step.predicate->first] = step.predicate->second; + w(); + w(); + w(); + std::cout << "set regs (all)\n"; + for (int i = 10; i < 100; ++i) + state.regs[i] = 1000 + i; + for (bool done = false; !done; done = w()) + ; + std::cout << "huzzah!\n"; + return 0; } -- cgit v1.2.3