Line data Source code
1 : #include <set>
2 :
3 : #include "Module/Stateful/Set/Set.hpp"
4 : #include "Runtime/Sequence/Sequence.hpp"
5 :
6 : namespace spu
7 : {
8 : namespace runtime
9 : {
10 : size_t
11 5919 : Sequence::get_n_threads() const
12 : {
13 5919 : return this->n_threads;
14 : }
15 :
16 : const std::vector<std::vector<runtime::Task*>>&
17 116 : Sequence::get_firsts_tasks() const
18 : {
19 116 : return this->firsts_tasks;
20 : }
21 :
22 : const std::vector<std::vector<runtime::Task*>>&
23 210 : Sequence::get_lasts_tasks() const
24 : {
25 210 : return this->lasts_tasks;
26 : }
27 :
28 : template<class C>
29 : std::vector<C*>
30 762 : Sequence::get_modules(const bool set_modules) const
31 : {
32 762 : std::vector<C*> ret;
33 7750 : for (auto& mm : this->all_modules)
34 56839 : for (auto& m : mm)
35 : {
36 49851 : if (set_modules)
37 : {
38 49851 : auto c = dynamic_cast<module::Set*>(m);
39 49851 : if (c != nullptr)
40 : {
41 48 : auto subret = c->get_sequence().get_modules<C>(set_modules);
42 48 : ret.insert(ret.end(), subret.begin(), subret.end());
43 48 : }
44 : }
45 :
46 49851 : auto c = dynamic_cast<C*>(m);
47 49851 : if (c != nullptr) ret.push_back(c);
48 : }
49 :
50 762 : return ret;
51 0 : }
52 :
53 : template<class C>
54 : std::vector<C*>
55 : Sequence::get_cloned_modules(const C& module_ref) const
56 : {
57 : bool found = false;
58 : size_t mid = 0;
59 : while (mid < this->all_modules[0].size() && !found)
60 : if (dynamic_cast<C*>(this->all_modules[0][mid]) == &module_ref)
61 : found = true;
62 : else
63 : mid++;
64 :
65 : if (!found)
66 : {
67 : std::stringstream message;
68 : message << "'module_ref' can't be found in the sequence.";
69 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
70 : }
71 :
72 : std::vector<C*> cloned_modules(this->all_modules.size());
73 : for (size_t tid = 0; tid < this->all_modules.size(); tid++)
74 : {
75 : auto c = dynamic_cast<C*>(this->all_modules[tid][mid]);
76 : if (c == nullptr)
77 : {
78 : std::stringstream message;
79 : message << "'c' can't be 'nullptr', this should never happen.";
80 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
81 : }
82 : cloned_modules[tid] = c;
83 : }
84 : return cloned_modules;
85 : }
86 :
87 : template<class SS>
88 : inline void
89 : Sequence::_init(tools::Digraph_node<SS>* root)
90 : {
91 : std::stringstream message;
92 : message << "This should never happen.";
93 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
94 : }
95 :
96 : template<>
97 : inline void
98 48 : Sequence::_init(tools::Digraph_node<runtime::Sub_sequence_const>* root)
99 : {
100 48 : this->replicate<runtime::Sub_sequence_const, const module::Module>(root);
101 48 : std::vector<tools::Digraph_node<Sub_sequence_const>*> already_deleted_nodes;
102 48 : this->delete_tree(root, already_deleted_nodes);
103 48 : }
104 :
105 : template<>
106 : inline void
107 464 : Sequence::_init(tools::Digraph_node<runtime::Sub_sequence>* root)
108 : {
109 : std::function<void(tools::Digraph_node<runtime::Sub_sequence>*,
110 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*>&)>
111 464 : remove_useless_nodes;
112 872 : remove_useless_nodes = [&](tools::Digraph_node<runtime::Sub_sequence>* node,
113 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*>& already_parsed_nodes)
114 : {
115 1744 : if (node != nullptr &&
116 1744 : std::find(already_parsed_nodes.begin(), already_parsed_nodes.end(), node) == already_parsed_nodes.end())
117 : {
118 822 : auto node_contents = node->get_c();
119 :
120 822 : if (node->get_parents().size() == 1 && node->get_children().size() == 1 && node_contents->tasks.size() == 0)
121 : {
122 56 : auto parent = node->get_parents().size() ? node->get_parents()[0] : nullptr;
123 56 : auto child = node->get_children().size() ? node->get_children()[0] : nullptr;
124 :
125 56 : auto child_pos = -1;
126 56 : if (parent != nullptr)
127 : {
128 56 : child_pos = node->get_child_pos(*parent);
129 56 : if (child_pos == -1)
130 : {
131 0 : std::stringstream message;
132 0 : message << "'child_pos' should be different from '-1'.";
133 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
134 0 : }
135 :
136 56 : if (!parent->cut_child((size_t)child_pos))
137 : {
138 0 : std::stringstream message;
139 0 : message << "'parent->cut_child(child_pos)' should return true ('child_pos' = " << child_pos
140 0 : << ").";
141 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
142 0 : }
143 : }
144 :
145 56 : auto parent_pos = -1;
146 56 : if (child != nullptr)
147 : {
148 56 : parent_pos = node->get_parent_pos(*child);
149 56 : if (parent_pos == -1)
150 : {
151 0 : std::stringstream message;
152 0 : message << "'parent_pos' should be different from '-1'.";
153 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
154 0 : }
155 :
156 56 : if (!child->cut_parent((size_t)parent_pos))
157 : {
158 0 : std::stringstream message;
159 0 : message << "'child->cut_parent(parent_pos)' should return true ('parent_pos' = " << parent_pos
160 0 : << ").";
161 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
162 0 : }
163 : }
164 :
165 56 : if (node == root) root = child;
166 :
167 56 : delete node_contents;
168 56 : delete node;
169 :
170 56 : if (child != nullptr && parent != nullptr)
171 : {
172 56 : parent->add_child(child, child_pos);
173 56 : child->add_parent(parent, parent_pos);
174 : }
175 :
176 56 : node = child;
177 56 : if (child != nullptr) node_contents = node->get_c();
178 : }
179 :
180 1644 : if (node != nullptr &&
181 1644 : std::find(already_parsed_nodes.begin(), already_parsed_nodes.end(), node) == already_parsed_nodes.end())
182 : {
183 788 : already_parsed_nodes.push_back(node);
184 :
185 788 : if (node->get_parents().size())
186 : {
187 324 : size_t min_depth = node->get_parents()[0]->get_depth();
188 408 : for (size_t f = 1; f < node->get_parents().size(); f++)
189 84 : min_depth = std::min(min_depth, node->get_parents()[f]->get_depth());
190 324 : node->set_depth(min_depth + 1);
191 : }
192 : else
193 464 : node->set_depth(0);
194 :
195 1196 : for (auto c : node->get_children())
196 408 : remove_useless_nodes(c, already_parsed_nodes);
197 : }
198 : }
199 1336 : };
200 464 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*> already_parsed_nodes1;
201 464 : remove_useless_nodes(root, already_parsed_nodes1);
202 :
203 : std::function<void(tools::Digraph_node<runtime::Sub_sequence>*,
204 : size_t&,
205 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*>&,
206 : std::map<tools::Digraph_node<runtime::Sub_sequence>*, std::pair<size_t, size_t>>&)>
207 464 : init_ss_ids_rec;
208 : init_ss_ids_rec =
209 788 : [&](tools::Digraph_node<runtime::Sub_sequence>* node,
210 : size_t& ssid,
211 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*>& already_parsed_nodes,
212 : std::map<tools::Digraph_node<runtime::Sub_sequence>*, std::pair<size_t, size_t>>& select_parents)
213 : {
214 1576 : if (node != nullptr &&
215 1576 : std::find(already_parsed_nodes.begin(), already_parsed_nodes.end(), node) == already_parsed_nodes.end())
216 : {
217 788 : already_parsed_nodes.push_back(node);
218 788 : auto node_contents = node->get_c();
219 788 : node_contents->id = ssid++;
220 :
221 1196 : for (auto c : node->get_children())
222 408 : switch (c->get_c()->type)
223 : {
224 158 : case subseq_t::SELECT:
225 : {
226 158 : if (select_parents.find(c) == select_parents.end())
227 : {
228 74 : size_t n_parents = 0;
229 232 : for (auto f : c->get_parents())
230 : {
231 158 : if (f->get_depth() < c->get_depth()) n_parents++;
232 : }
233 74 : select_parents[c] = std::make_pair(0, n_parents);
234 : }
235 158 : std::get<0>(select_parents[c]) += 1;
236 :
237 158 : if (std::get<0>(select_parents[c]) == std::get<1>(select_parents[c]))
238 74 : init_ss_ids_rec(c, ssid, already_parsed_nodes, select_parents);
239 158 : break;
240 : }
241 74 : case subseq_t::COMMUTE:
242 : {
243 74 : init_ss_ids_rec(c, ssid, already_parsed_nodes, select_parents);
244 74 : break;
245 : }
246 176 : case subseq_t::STD:
247 : {
248 176 : init_ss_ids_rec(c, ssid, already_parsed_nodes, select_parents);
249 176 : break;
250 : }
251 0 : default:
252 : {
253 0 : std::stringstream message;
254 0 : message << "This should never happen.";
255 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
256 0 : }
257 : };
258 : }
259 1252 : };
260 464 : std::map<tools::Digraph_node<runtime::Sub_sequence>*, std::pair<size_t, size_t>> select_parents;
261 464 : std::vector<tools::Digraph_node<runtime::Sub_sequence>*> already_parsed_nodes2;
262 464 : size_t ssid = 0;
263 464 : init_ss_ids_rec(root, ssid, already_parsed_nodes2, select_parents);
264 :
265 464 : this->sequences[0] = root;
266 :
267 464 : std::set<module::Module*> modules_set;
268 : std::function<void(const tools::Digraph_node<runtime::Sub_sequence>*,
269 : std::vector<const tools::Digraph_node<runtime::Sub_sequence>*>&)>
270 464 : collect_modules_list;
271 872 : collect_modules_list = [&](const tools::Digraph_node<runtime::Sub_sequence>* node,
272 : std::vector<const tools::Digraph_node<runtime::Sub_sequence>*>& already_parsed_nodes)
273 : {
274 1744 : if (node != nullptr &&
275 1744 : std::find(already_parsed_nodes.begin(), already_parsed_nodes.end(), node) == already_parsed_nodes.end())
276 : {
277 788 : already_parsed_nodes.push_back(node);
278 788 : if (node->get_c())
279 3280 : for (auto ta : node->get_c()->tasks)
280 2492 : modules_set.insert(&ta->get_module());
281 1196 : for (auto c : node->get_children())
282 408 : collect_modules_list(c, already_parsed_nodes);
283 : }
284 1336 : };
285 464 : std::vector<const tools::Digraph_node<runtime::Sub_sequence>*> already_parsed_nodes3;
286 464 : collect_modules_list(root, already_parsed_nodes3);
287 :
288 2871 : for (auto m : modules_set)
289 2407 : this->all_modules[0].push_back(m);
290 :
291 464 : this->replicate<runtime::Sub_sequence, module::Module>(root);
292 464 : }
293 :
294 : size_t
295 840 : Sequence::get_n_frames() const
296 : {
297 840 : const auto n_frames = this->all_modules[0][0]->get_n_frames();
298 :
299 10340 : for (auto& mm : this->all_modules)
300 80662 : for (auto& m : mm)
301 71162 : if (m->get_n_frames() != n_frames)
302 : {
303 0 : std::stringstream message;
304 0 : message << "All the modules do not have the same 'n_frames' value ('m->get_n_frames()' = "
305 0 : << m->get_n_frames() << ", 'n_frames' = " << n_frames << ").";
306 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
307 0 : }
308 :
309 840 : return n_frames;
310 : }
311 :
312 : }
313 : }
|