LearnTA  0.0.1
zone_automaton.hh
#pragma once

#include <algorithm>
#include <climits>
#include <iterator>
#include <memory>
#include <optional>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "timed_automaton.hh"
#include "zone.hh"
#include "timed_word.hh"
#include "symbolic_run.hh"
#include "zone_automaton_state.hh"
12 
13 namespace learnta {
17  struct ZoneAutomaton : public Automaton<ZAState> {
23  [[nodiscard]] std::optional<TimedWord> sample() const {
24  std::vector<SymbolicRun> currentStates;
25  currentStates.reserve(initialStates.size());
26  std::transform(initialStates.begin(), initialStates.end(), std::back_inserter(currentStates),
27  [](const auto &initialState) {
28  return SymbolicRun{initialState};
29  });
30  std::unordered_set<std::shared_ptr<ZAState>> visited = {initialStates.begin(), initialStates.end()};
31 
32  while (!currentStates.empty()) {
33  std::vector<SymbolicRun> nextStates;
34  for (const auto &run: currentStates) {
35  if (run.back()->isMatch) {
36  // run is a positive run
37  auto wordOpt = run.reconstructWord();
38  if (wordOpt) {
39  return wordOpt;
40  }
41  }
42  for (int action = 0; action < CHAR_MAX; ++action) {
43  const auto &edges = run.back()->next[action];
44  for (const auto &edge: edges) {
45  auto transition = edge.first;
46  auto target = edge.second.lock();
47  if (target && visited.find(target) == visited.end()) {
48  // We have not visited the state
49  auto newRun = run;
50  newRun.push_back(transition, static_cast<char>(action), target);
51  nextStates.push_back(newRun);
52  visited.insert(target);
53  }
54  }
55  }
56  }
57  currentStates = std::move(nextStates);
58  }
59 
60  return std::nullopt;
61  }
62 
63  std::optional<TimedWord> sampleMemo;
64  std::optional<TimedWord> sampleWithMemo() {
65  if (sampleMemo) {
66  return sampleMemo;
67  } else {
68  sampleMemo = sample();
69  return sampleMemo;
70  }
71  }
72 
77  std::unordered_map<std::shared_ptr<ZAState>, std::unordered_set<std::shared_ptr<ZAState>>> backwardEdges;
78  for (const auto &state: this->states) {
79  for (const auto &transitions: state->next) {
80  for (const auto &[transition, target]: transitions) {
81  auto it = backwardEdges.find(target.lock());
82  if (it == backwardEdges.end()) {
83  backwardEdges[target.lock()] = {state};
84  } else {
85  backwardEdges.at(target.lock()).insert(state);
86  }
87  }
88  }
89  }
90 
91  // The states reachable to an accepting state
92  std::unordered_set<std::shared_ptr<ZAState>> liveStates;
93  std::queue<std::shared_ptr<ZAState>> newLiveStates;
94  for (const auto &state: this->states) {
95  if (state->isMatch) {
96  liveStates.insert(state);
97  newLiveStates.push(state);
98  }
99  }
100  while (!newLiveStates.empty()) {
101  const auto newLiveState = newLiveStates.front();
102  newLiveStates.pop();
103  if (backwardEdges.find(newLiveState) != backwardEdges.end()) {
104  for (const auto &backwardState: backwardEdges.at(newLiveState)) {
105  if (liveStates.find(backwardState) == liveStates.end()) {
106  liveStates.insert(backwardState);
107  newLiveStates.push(backwardState);
108  }
109  }
110  }
111  }
112 
113  if (liveStates.size() != this->stateSize()) {
114  // Remove dead states if exists
115  BOOST_LOG_TRIVIAL(info) << "There are " << this->stateSize() - liveStates.size() << " dead states in the zone graph";
116  this->states.erase(std::remove_if(this->states.begin(), this->states.end(), [&](const auto &state) {
117  return liveStates.find(state) == liveStates.end();
118  }), this->states.end());
119  this->initialStates.erase(
120  std::remove_if(this->initialStates.begin(), this->initialStates.end(), [&](const auto &state) {
121  return liveStates.find(state) == liveStates.end();
122  }), this->initialStates.end());
123  for (const auto &state: this->states) {
124  for (auto &transitions: state->next) {
125  transitions.erase(std::remove_if(transitions.begin(), transitions.end(), [&](const auto &pair) {
126  const auto &[transition, target] = pair;
127  return liveStates.find(target.lock()) == liveStates.end();
128  }), transitions.end());
129  }
130  }
131  }
132  }
133  };
134 }
Definition: experiment_runner.hh:23
An automaton.
Definition: common_types.hh:18
std::vector< std::shared_ptr< ZAState > > initialStates
The initial states of this automaton.
Definition: common_types.hh:24
A Zone automaton.
Definition: zone_automaton.hh:17
std::optional< TimedWord > sample() const
Sample a timed word in this zone automaton.
Definition: zone_automaton.hh:23
void removeDeadStates()
Remove dead states, i.e., states from which no accepting state is reachable.
Definition: zone_automaton.hh:76