LearnTA  0.0.1
membership_oracle.hh
1 
6 #pragma once
7 #include <memory>
8 #include <boost/unordered_map.hpp>
9 
10 #include "sul.hh"
11 #include "timed_word.hh"
12 
13 namespace learnta {
18  public:
19  virtual bool answerQuery(const TimedWord &timedWord) = 0;
20  [[nodiscard]] virtual std::size_t count() const = 0;
21  virtual ~MembershipOracle() = default;
22 
23  virtual std::ostream &printStatistics(std::ostream &stream) const {
24  stream << "Number of membership queries: " << this->count() << "\n";
25 
26  return stream;
27  }
28  };
29 
33  class SULMembershipOracle final : public MembershipOracle {
34  private:
35  std::unique_ptr<SUL> sul;
36  public:
37  explicit SULMembershipOracle(std::unique_ptr<SUL> &&sul) : sul(std::move(sul)) {}
38 
39  bool answerQuery(const learnta::TimedWord &timedWord) override {
40  sul->pre();
41  std::string word = timedWord.getWord();
42  std::vector<double> duration = timedWord.getDurations();
43  bool result = sul->step(duration[0]);
44  for (std::size_t i = 0; i < timedWord.wordSize(); i++) {
45  sul->step(word[i]);
46  result = sul->step(duration[i + 1]);
47  }
48  sul->post();
49 
50  return result;
51  }
52 
53  [[nodiscard]] size_t count() const override {
54  return this->sul->count();
55  }
56  };
57 
61  class MembershipOracleCache final : public MembershipOracle {
62  std::unique_ptr<MembershipOracle> oracle;
63  boost::unordered_map<TimedWord, bool> membershipCache;
64  std::size_t countNoCache = 0;
65 
66  public:
67  explicit MembershipOracleCache(std::unique_ptr<MembershipOracle> &&oracle) : oracle(std::move(oracle)) {}
68 
69  bool answerQuery(const TimedWord &timedWord) override {
70  ++countNoCache;
71  auto it = this->membershipCache.find(timedWord);
72  if (it != membershipCache.end()) {
73  return it->second;
74  }
75  const auto result = this->oracle->answerQuery(timedWord);
76  this->membershipCache[timedWord] = result;
77 
78  return result;
79  }
80 
81  [[nodiscard]] size_t count() const override {
82  return this->oracle->count();
83  }
84 
85  std::ostream &printStatistics(std::ostream &stream) const override {
86  stream << "Number of membership queries: " << countNoCache << "\n";
87  stream << "Number of membership queries (with cache): " << this->count() << "\n";
88 
89  return stream;
90  }
91  };
92 }
Wrapper of a membership oracle to cache the result.
Definition: membership_oracle.hh:61
Interface of a membership oracle.
Definition: membership_oracle.hh:17
Membership oracle defined by an SUL.
Definition: membership_oracle.hh:33
A timed word.
Definition: timed_word.hh:25
std::size_t wordSize() const
Return the number of actions in this timed word.
Definition: timed_word.hh:150
Definition: experiment_runner.hh:23