Line data Source code
1 : #include <algorithm>
2 : #include <iomanip>
3 : #include <ios>
4 : #include <iostream>
5 : #include <rang.hpp>
6 : #include <sstream>
7 : #include <typeinfo>
8 :
9 : #include "Module/Module.hpp"
10 : #include "Module/Stateful/Set/Set.hpp"
11 : #include "Runtime/Socket/Socket.hpp"
12 : #include "Runtime/Task/Task.hpp"
13 : #include "Tools/Exception/exception.hpp"
14 :
15 : using namespace spu;
16 : using namespace spu::runtime;
17 :
18 3398 : Task::Task(module::Module& module,
19 : const std::string& name,
20 : const bool stats,
21 : const bool fast,
22 : const bool debug,
23 3398 : const bool outbuffers_allocated)
24 3398 : : module(&module)
25 3398 : , name(name)
26 3398 : , stats(stats)
27 3398 : , fast(fast)
28 3398 : , debug(debug)
29 3398 : , outbuffers_allocated(outbuffers_allocated)
30 3398 : , debug_hex(false)
31 3398 : , replicable(module.is_clonable())
32 3398 : , debug_limit(-1)
33 3398 : , debug_precision(2)
34 3398 : , debug_frame_max(-1)
35 3398 : , codelet(
36 0 : [](module::Module& /*m*/, Task& /*t*/, const size_t /*frame_id*/) -> int
37 : {
38 0 : throw tools::unimplemented_error(__FILE__, __LINE__, __func__);
39 : return 0;
40 : })
41 3398 : , n_input_sockets(0)
42 3398 : , n_output_sockets(0)
43 3398 : , n_fwd_sockets(0)
44 3398 : , status(module.get_n_waves())
45 3398 : , n_calls(0)
46 3398 : , duration_total(std::chrono::nanoseconds(0))
47 3398 : , duration_min(std::chrono::nanoseconds(0))
48 3398 : , duration_max(std::chrono::nanoseconds(0))
49 6796 : , last_input_socket(nullptr)
50 : {
51 3398 : }
52 :
53 : Socket&
54 0 : Task::operator[](const std::string& sck_name)
55 : {
56 0 : std::string s_name = sck_name;
57 0 : s_name.erase(remove(s_name.begin(), s_name.end(), ' '), s_name.end());
58 :
59 0 : auto it = find_if(this->sockets.begin(),
60 : this->sockets.end(),
61 0 : [s_name](std::shared_ptr<runtime::Socket> s) { return s->get_name() == s_name; });
62 :
63 0 : if (it == this->sockets.end())
64 : {
65 0 : std::stringstream message;
66 0 : message << "runtime::Socket '" << s_name << "' not found for task '" << this->get_name() << "'.";
67 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
68 0 : }
69 :
70 0 : return *it->get();
71 0 : }
72 :
73 : void
74 80207 : Task::set_stats(const bool stats)
75 : {
76 80207 : this->stats = stats;
77 80207 : }
78 :
79 : void
80 85375 : Task::set_fast(const bool fast)
81 : {
82 85375 : this->fast = fast;
83 309841 : for (size_t i = 0; i < sockets.size(); i++)
84 224466 : sockets[i]->set_fast(this->fast);
85 85375 : }
86 :
87 : void
88 80115 : Task::set_debug(const bool debug)
89 : {
90 80115 : this->debug = debug;
91 80115 : }
92 :
93 : void
94 0 : Task::set_debug_hex(const bool debug_hex)
95 : {
96 0 : this->debug_hex = debug_hex;
97 0 : }
98 :
99 : void
100 80115 : Task::set_debug_limit(const uint32_t limit)
101 : {
102 80115 : this->debug_limit = (int32_t)limit;
103 80115 : }
104 :
105 : void
106 0 : Task::set_debug_precision(const uint8_t prec)
107 : {
108 0 : this->debug_precision = prec;
109 0 : }
110 :
111 : void
112 0 : Task::set_debug_frame_max(const uint32_t limit)
113 : {
114 0 : this->debug_frame_max = limit;
115 0 : }
116 :
117 : // trick to compile on the GNU compiler version 4 (where 'std::hexfloat' is unavailable)
118 : #if !defined(__clang__) && !defined(__llvm__) && defined(__GNUC__) && defined(__cplusplus) && __GNUC__ < 5
119 : namespace std
120 : {
121 : class Hexfloat
122 : {
123 : public:
124 : void message(std::ostream& os) const { os << " /!\\ 'std::hexfloat' is not supported by this compiler. /!\\ "; }
125 : };
126 : Hexfloat hexfloat;
127 : }
128 : std::ostream&
129 : operator<<(std::ostream& os, const std::Hexfloat& obj)
130 : {
131 : obj.message(os);
132 : return os;
133 : }
134 : #endif
135 :
136 : template<typename T>
137 : static inline void
138 403 : display_data(const T* data,
139 : const size_t fra_size,
140 : const size_t n_fra,
141 : const size_t n_fra_per_w,
142 : const size_t limit,
143 : const size_t max_frame,
144 : const uint8_t p,
145 : const uint8_t n_spaces,
146 : const bool hex)
147 : {
148 403 : constexpr bool is_float_type = std::is_same<float, T>::value || std::is_same<double, T>::value;
149 :
150 403 : std::ios::fmtflags f(std::cout.flags());
151 403 : if (hex)
152 : {
153 : if (is_float_type)
154 0 : std::cout << std::hexfloat << std::hex;
155 : else
156 0 : std::cout << std::hex;
157 : }
158 : else
159 403 : std::cout << std::fixed << std::setprecision(p) << std::dec;
160 :
161 403 : if (n_fra == 1 && max_frame != 0)
162 : {
163 5635 : for (size_t i = 0; i < limit; i++)
164 : {
165 5232 : if (hex)
166 0 : std::cout << (!is_float_type ? "0x" : "") << +data[i] << (i < limit - 1 ? ", " : "");
167 : else
168 5232 : std::cout << std::setw(p + 3) << +data[i] << (i < limit - 1 ? ", " : "");
169 : }
170 403 : std::cout << (limit < fra_size ? ", ..." : "");
171 403 : }
172 : else
173 : {
174 0 : std::string spaces = "#";
175 0 : for (uint8_t s = 0; s < n_spaces - 1; s++)
176 0 : spaces += " ";
177 :
178 0 : auto n_digits_dec = [](size_t f) -> size_t
179 : {
180 0 : size_t count = 0;
181 0 : while (f && ++count)
182 0 : f /= 10;
183 0 : return count;
184 : };
185 :
186 0 : const auto n_digits = n_digits_dec(max_frame);
187 0 : auto ftos = [&n_digits_dec, &n_digits](size_t f) -> std::string
188 : {
189 0 : auto n_zero = n_digits - n_digits_dec(f);
190 0 : std::string f_str = "";
191 0 : for (size_t z = 0; z < n_zero; z++)
192 0 : f_str += "0";
193 0 : f_str += std::to_string(f);
194 0 : return f_str;
195 0 : };
196 :
197 0 : const auto n_digits_w = n_digits_dec((max_frame / n_fra_per_w) == 0 ? 1 : (max_frame / n_fra_per_w));
198 0 : auto wtos = [&n_digits_dec, &n_digits_w](size_t w) -> std::string
199 : {
200 0 : auto n_zero = n_digits_w - n_digits_dec(w);
201 0 : std::string f_str = "";
202 0 : for (size_t z = 0; z < n_zero; z++)
203 0 : f_str += "0";
204 0 : f_str += std::to_string(w);
205 0 : return f_str;
206 0 : };
207 :
208 0 : for (size_t f = 0; f < max_frame; f++)
209 : {
210 0 : const auto w = f / n_fra_per_w;
211 0 : std::cout << (f >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f" << ftos(f + 1) << "_w"
212 0 : << wtos(w + 1) << rang::style::reset << "(";
213 :
214 0 : for (size_t i = 0; i < limit; i++)
215 : {
216 0 : if (hex)
217 0 : std::cout << (!is_float_type ? "0x" : "") << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
218 : else
219 0 : std::cout << std::setw(p + 3) << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
220 : }
221 0 : std::cout << (limit < fra_size ? ", ..." : "") << ")" << (f < n_fra - 1 ? ", \n" : "");
222 : }
223 :
224 0 : if (max_frame < n_fra)
225 : {
226 0 : const auto w1 = max_frame / n_fra_per_w;
227 0 : const auto w2 = n_fra / n_fra_per_w;
228 0 : std::cout << (max_frame >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f"
229 0 : << std::setw(n_digits) << max_frame + 1 << "_w" << std::setw(n_digits_w) << w1 + 1 << "->"
230 0 : << "f" << std::setw(n_digits) << n_fra << "_w" << std::setw(n_digits_w) << w2 + 1 << ":"
231 0 : << rang::style::reset << "(...)";
232 : }
233 0 : }
234 :
235 403 : std::cout.flags(f);
236 403 : }
237 :
238 : void
239 3293723 : Task::_exec(const int frame_id, const bool managed_memory)
240 : {
241 3293723 : const auto n_frames = this->get_module().get_n_frames();
242 3270494 : const auto n_frames_per_wave = this->get_module().get_n_frames_per_wave();
243 3231806 : const auto n_waves = this->get_module().get_n_waves();
244 3197840 : const auto n_frames_per_wave_rest = this->get_module().get_n_frames_per_wave_rest();
245 :
246 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using the no
247 : // copy mode
248 3172115 : int* status = this->sockets.back()->get_dataptr<int>();
249 7747026 : for (size_t w = 0; w < n_waves; w++)
250 4746314 : status[w] = (int)status_t::UNKNOWN;
251 :
252 3000712 : if ((managed_memory == false && frame_id >= 0) || (frame_id == -1 && n_frames_per_wave == n_frames) ||
253 161531 : (frame_id == 0 && n_frames_per_wave == 1) || (frame_id == 0 && n_waves > 1) ||
254 0 : (frame_id == 0 && n_frames_per_wave_rest == 0))
255 : {
256 2840338 : const auto real_frame_id = frame_id == -1 ? 0 : frame_id;
257 2840338 : const size_t w = (real_frame_id % n_frames) / n_frames_per_wave;
258 2840338 : status[w] = this->codelet(*this->module, *this, real_frame_id);
259 3127439 : }
260 : else
261 : {
262 : // save the initial dataptr of the sockets
263 454663 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
264 290554 : sockets_dataptr_init[sid] = (int8_t*)this->sockets[sid]->_get_dataptr();
265 :
266 162685 : if (frame_id > 0 && managed_memory == true && n_frames_per_wave > 1)
267 : {
268 0 : const size_t w = (frame_id % n_frames) / n_frames_per_wave;
269 0 : const size_t w_pos = frame_id % n_frames_per_wave;
270 :
271 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
272 : {
273 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
274 0 : std::copy(
275 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 0) * sockets_databytes_per_frame[sid],
276 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 1) * sockets_databytes_per_frame[sid],
277 0 : sockets_data[sid].begin() + w_pos * sockets_databytes_per_frame[sid]);
278 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
279 : }
280 :
281 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
282 :
283 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
284 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
285 0 : std::copy(sockets_data[sid].begin() + (w_pos + 0) * sockets_databytes_per_frame[sid],
286 0 : sockets_data[sid].begin() + (w_pos + 1) * sockets_databytes_per_frame[sid],
287 0 : sockets_dataptr_init[sid] + (frame_id % n_frames) * sockets_databytes_per_frame[sid]);
288 0 : }
289 : else // if (frame_id <= 0 || n_frames_per_wave == 1)
290 : {
291 162685 : const size_t w_start = (frame_id < 0) ? 0 : frame_id % n_waves;
292 162685 : const size_t w_stop = (frame_id < 0) ? n_waves : w_start + 1;
293 :
294 162685 : size_t w = 0;
295 162685 : auto exec_status = status_t::SUCCESS;
296 1657174 : for (w = w_start; w < w_stop - 1 && exec_status != status_t::FAILURE_STOP; w++)
297 : {
298 4188356 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
299 2644960 : this->sockets[sid]->dataptr =
300 5326060 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
301 :
302 1246176 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
303 1494489 : exec_status = (status_t)status[w];
304 : }
305 :
306 152461 : if (exec_status != status_t::FAILURE_STOP)
307 : {
308 162623 : if (n_frames_per_wave_rest == 0)
309 : {
310 455419 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
311 292390 : this->sockets[sid]->dataptr =
312 585086 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
313 :
314 156643 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
315 : }
316 : else
317 : {
318 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
319 : {
320 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
321 0 : std::copy(sockets_dataptr_init[sid] +
322 0 : w * n_frames_per_wave * sockets_databytes_per_frame[sid],
323 0 : sockets_dataptr_init[sid] + n_frames * sockets_databytes_per_frame[sid],
324 0 : sockets_data[sid].begin());
325 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
326 : }
327 :
328 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
329 :
330 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
331 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
332 0 : std::copy(
333 0 : sockets_data[sid].begin(),
334 0 : sockets_data[sid].begin() + n_frames_per_wave_rest * sockets_databytes_per_frame[sid],
335 0 : sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
336 : }
337 : }
338 : }
339 :
340 : // restore the initial dataptr of the sockets
341 453875 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
342 291620 : this->sockets[sid]->dataptr = (void*)sockets_dataptr_init[sid];
343 : }
344 3284343 : }
345 :
346 : const std::vector<int>&
347 3222601 : Task::exec(const int frame_id, const bool managed_memory)
348 : {
349 : #ifndef SPU_FAST
350 3222601 : if (this->is_fast() && !this->is_debug() && !this->is_stats())
351 : {
352 : #endif
353 3236032 : this->_exec(frame_id, managed_memory);
354 3259305 : this->n_calls++;
355 3259305 : return this->get_status();
356 : #ifndef SPU_FAST
357 : }
358 :
359 7163 : if (frame_id != -1 && (size_t)frame_id >= this->get_module().get_n_frames())
360 : {
361 0 : std::stringstream message;
362 0 : message << "'frame_id' has to be equal to '-1' or to be smaller than 'n_frames' ('frame_id' = " << frame_id
363 0 : << ", 'n_frames' = " << this->get_module().get_n_frames() << ").";
364 0 : throw tools::length_error(__FILE__, __LINE__, __func__, message.str());
365 0 : }
366 :
367 8423 : if (this->is_fast() || this->can_exec())
368 : {
369 8407 : size_t max_n_chars = 0;
370 8407 : if (this->is_debug())
371 : {
372 335 : auto n_fra = this->module->get_n_frames();
373 335 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
374 :
375 : std::string module_name =
376 335 : module->get_custom_name().empty() ? module->get_name() : module->get_custom_name();
377 :
378 335 : std::cout << "# ";
379 335 : std::cout << rang::style::bold << rang::fg::green << module_name << rang::style::reset
380 335 : << "::" << rang::style::bold << rang::fg::magenta << get_name() << rang::style::reset << "(";
381 740 : for (auto i = 0; i < (int)sockets.size() - 1; i++)
382 : {
383 405 : auto& s = *sockets[i];
384 405 : auto s_type = s.get_type();
385 405 : auto n_elmts = s.get_databytes() / (size_t)s.get_datatype_size();
386 405 : std::cout << rang::style::bold << rang::fg::blue << (s_type == socket_t::SIN ? "const " : "")
387 405 : << s.get_datatype_string() << rang::style::reset << " " << s.get_name() << "["
388 810 : << (n_fra > 1 ? std::to_string(n_fra) + "x" : "") << (n_elmts / n_fra) << "]"
389 405 : << (i < (int)sockets.size() - 2 ? ", " : "");
390 :
391 405 : max_n_chars = std::max(s.get_name().size(), max_n_chars);
392 : }
393 335 : std::cout << ")" << std::endl;
394 :
395 1075 : for (auto& s : sockets)
396 : {
397 740 : auto s_type = s->get_type();
398 740 : if (s_type == socket_t::SIN || s_type == socket_t::SFWD)
399 : {
400 218 : std::string spaces;
401 346 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
402 128 : spaces += " ";
403 :
404 218 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
405 218 : auto fra_size = n_elmts / n_fra;
406 218 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
407 218 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
408 218 : auto p = debug_precision;
409 218 : auto h = debug_hex;
410 218 : std::cout << "# {IN} " << s->get_name() << spaces << " = [";
411 218 : if (s->get_datatype() == typeid(int8_t))
412 0 : display_data(s->get_dataptr<const int8_t>(),
413 : fra_size,
414 : n_fra,
415 : n_fra_per_w,
416 : limit,
417 : max_frame,
418 : p,
419 0 : (uint8_t)max_n_chars + 12,
420 : h);
421 218 : else if (s->get_datatype() == typeid(uint8_t))
422 173 : display_data(s->get_dataptr<const uint8_t>(),
423 : fra_size,
424 : n_fra,
425 : n_fra_per_w,
426 : limit,
427 : max_frame,
428 : p,
429 173 : (uint8_t)max_n_chars + 12,
430 : h);
431 45 : else if (s->get_datatype() == typeid(int16_t))
432 0 : display_data(s->get_dataptr<const int16_t>(),
433 : fra_size,
434 : n_fra,
435 : n_fra_per_w,
436 : limit,
437 : max_frame,
438 : p,
439 0 : (uint8_t)max_n_chars + 12,
440 : h);
441 45 : else if (s->get_datatype() == typeid(uint16_t))
442 0 : display_data(s->get_dataptr<const uint16_t>(),
443 : fra_size,
444 : n_fra,
445 : n_fra_per_w,
446 : limit,
447 : max_frame,
448 : p,
449 0 : (uint8_t)max_n_chars + 12,
450 : h);
451 45 : else if (s->get_datatype() == typeid(int32_t))
452 0 : display_data(s->get_dataptr<const int32_t>(),
453 : fra_size,
454 : n_fra,
455 : n_fra_per_w,
456 : limit,
457 : max_frame,
458 : p,
459 0 : (uint8_t)max_n_chars + 12,
460 : h);
461 45 : else if (s->get_datatype() == typeid(uint32_t))
462 27 : display_data(s->get_dataptr<const uint32_t>(),
463 : fra_size,
464 : n_fra,
465 : n_fra_per_w,
466 : limit,
467 : max_frame,
468 : p,
469 27 : (uint8_t)max_n_chars + 12,
470 : h);
471 18 : else if (s->get_datatype() == typeid(int64_t))
472 0 : display_data(s->get_dataptr<const int64_t>(),
473 : fra_size,
474 : n_fra,
475 : n_fra_per_w,
476 : limit,
477 : max_frame,
478 : p,
479 0 : (uint8_t)max_n_chars + 12,
480 : h);
481 18 : else if (s->get_datatype() == typeid(uint64_t))
482 18 : display_data(s->get_dataptr<const uint64_t>(),
483 : fra_size,
484 : n_fra,
485 : n_fra_per_w,
486 : limit,
487 : max_frame,
488 : p,
489 18 : (uint8_t)max_n_chars + 12,
490 : h);
491 0 : else if (s->get_datatype() == typeid(float))
492 0 : display_data(s->get_dataptr<const float>(),
493 : fra_size,
494 : n_fra,
495 : n_fra_per_w,
496 : limit,
497 : max_frame,
498 : p,
499 0 : (uint8_t)max_n_chars + 12,
500 : h);
501 0 : else if (s->get_datatype() == typeid(double))
502 0 : display_data(s->get_dataptr<const double>(),
503 : fra_size,
504 : n_fra,
505 : n_fra_per_w,
506 : limit,
507 : max_frame,
508 : p,
509 0 : (uint8_t)max_n_chars + 12,
510 : h);
511 218 : std::cout << "]" << std::endl;
512 218 : }
513 : }
514 335 : }
515 :
516 8404 : if (this->is_stats())
517 : {
518 8033 : auto t_start = std::chrono::steady_clock::now();
519 8054 : this->_exec(frame_id, managed_memory);
520 7951 : auto duration = std::chrono::steady_clock::now() - t_start;
521 :
522 7940 : this->duration_total += duration;
523 7926 : if (n_calls)
524 : {
525 7459 : this->duration_min = std::min(this->duration_min, duration);
526 7468 : this->duration_max = std::max(this->duration_max, duration);
527 : }
528 : else
529 : {
530 467 : this->duration_min = duration;
531 467 : this->duration_max = duration;
532 : }
533 : }
534 : else
535 : {
536 365 : this->_exec(frame_id, managed_memory);
537 : }
538 8316 : this->n_calls++;
539 :
540 8316 : if (this->is_debug())
541 : {
542 332 : auto n_fra = this->module->get_n_frames();
543 332 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
544 1064 : for (auto& s : sockets)
545 : {
546 732 : auto s_type = s->get_type();
547 731 : if ((s_type == socket_t::SOUT) && s->get_name() != "status")
548 : {
549 185 : std::string spaces;
550 203 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
551 18 : spaces += " ";
552 :
553 185 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
554 185 : auto fra_size = n_elmts / n_fra;
555 185 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
556 185 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
557 185 : auto p = debug_precision;
558 185 : auto h = debug_hex;
559 185 : std::cout << "# {OUT} " << s->get_name() << spaces << " = [";
560 185 : if (s->get_datatype() == typeid(int8_t))
561 0 : display_data(s->get_dataptr<const int8_t>(),
562 : fra_size,
563 : n_fra,
564 : n_fra_per_w,
565 : limit,
566 : max_frame,
567 : p,
568 0 : (uint8_t)max_n_chars + 12,
569 : h);
570 185 : else if (s->get_datatype() == typeid(uint8_t))
571 149 : display_data(s->get_dataptr<const uint8_t>(),
572 : fra_size,
573 : n_fra,
574 : n_fra_per_w,
575 : limit,
576 : max_frame,
577 : p,
578 149 : (uint8_t)max_n_chars + 12,
579 : h);
580 36 : else if (s->get_datatype() == typeid(int16_t))
581 0 : display_data(s->get_dataptr<const int16_t>(),
582 : fra_size,
583 : n_fra,
584 : n_fra_per_w,
585 : limit,
586 : max_frame,
587 : p,
588 0 : (uint8_t)max_n_chars + 12,
589 : h);
590 36 : else if (s->get_datatype() == typeid(uint16_t))
591 0 : display_data(s->get_dataptr<const uint16_t>(),
592 : fra_size,
593 : n_fra,
594 : n_fra_per_w,
595 : limit,
596 : max_frame,
597 : p,
598 0 : (uint8_t)max_n_chars + 12,
599 : h);
600 36 : else if (s->get_datatype() == typeid(int32_t))
601 0 : display_data(s->get_dataptr<const int32_t>(),
602 : fra_size,
603 : n_fra,
604 : n_fra_per_w,
605 : limit,
606 : max_frame,
607 : p,
608 0 : (uint8_t)max_n_chars + 12,
609 : h);
610 36 : else if (s->get_datatype() == typeid(uint32_t))
611 18 : display_data(s->get_dataptr<const uint32_t>(),
612 : fra_size,
613 : n_fra,
614 : n_fra_per_w,
615 : limit,
616 : max_frame,
617 : p,
618 18 : (uint8_t)max_n_chars + 12,
619 : h);
620 18 : else if (s->get_datatype() == typeid(int64_t))
621 0 : display_data(s->get_dataptr<const int64_t>(),
622 : fra_size,
623 : n_fra,
624 : n_fra_per_w,
625 : limit,
626 : max_frame,
627 : p,
628 0 : (uint8_t)max_n_chars + 12,
629 : h);
630 18 : else if (s->get_datatype() == typeid(uint64_t))
631 18 : display_data(s->get_dataptr<const uint64_t>(),
632 : fra_size,
633 : n_fra,
634 : n_fra_per_w,
635 : limit,
636 : max_frame,
637 : p,
638 18 : (uint8_t)max_n_chars + 12,
639 : h);
640 0 : else if (s->get_datatype() == typeid(float))
641 0 : display_data(s->get_dataptr<const float>(),
642 : fra_size,
643 : n_fra,
644 : n_fra_per_w,
645 : limit,
646 : max_frame,
647 : p,
648 0 : (uint8_t)max_n_chars + 12,
649 : h);
650 0 : else if (s->get_datatype() == typeid(double))
651 0 : display_data(s->get_dataptr<const double>(),
652 : fra_size,
653 : n_fra,
654 : n_fra_per_w,
655 : limit,
656 : max_frame,
657 : p,
658 0 : (uint8_t)max_n_chars + 12,
659 : h);
660 185 : std::cout << "]" << std::endl;
661 185 : }
662 : }
663 332 : std::cout << "# Returned status: [";
664 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using
665 : // the no copy mode
666 332 : int* status = this->sockets.back()->get_dataptr<int>();
667 664 : for (size_t w = 0; w < this->get_module().get_n_waves(); w++)
668 : {
669 332 : if (status_t_to_string.count(status[w]))
670 332 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w] << " '"
671 332 : << status_t_to_string[status[w]] << "'";
672 : else
673 0 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w];
674 : }
675 332 : std::cout << "]" << std::endl;
676 332 : std::cout << "#" << std::noshowbase << std::endl;
677 : }
678 :
679 : // if (exec_status < 0)
680 : // {
681 : // std::stringstream message;
682 : // message << "'exec_status' can't be negative ('exec_status' = " << exec_status << ").";
683 : // throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
684 : // }
685 :
686 8307 : return this->get_status();
687 : }
688 : else
689 : {
690 0 : std::stringstream socs;
691 0 : socs << "'socket(s).name' = [";
692 0 : auto s = 0;
693 0 : for (size_t i = 0; i < sockets.size(); i++)
694 0 : if (sockets[i]->dataptr == nullptr) socs << (s != 0 ? ", " : "") << sockets[i]->name;
695 0 : socs << "]";
696 :
697 0 : std::stringstream message;
698 : message << "The task cannot be executed because some of the inputs/output sockets are not fed ('task.name' = "
699 0 : << this->get_name() << ", 'module.name' = " << module->get_name() << ", " << socs.str() << ").";
700 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
701 0 : }
702 : #endif /* !SPU_FAST */
703 : }
704 :
705 : template<typename T>
706 : Socket&
707 8453 : Task::create_2d_socket(const std::string& name,
708 : const size_t n_rows,
709 : const size_t n_cols,
710 : const socket_t type,
711 : const bool hack_status)
712 : {
713 8453 : if (name.empty())
714 : {
715 0 : std::stringstream message;
716 : message << "Impossible to create this socket because the name is empty ('task.name' = " << this->get_name()
717 0 : << ", 'module.name' = " << module->get_name() << ").";
718 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
719 0 : }
720 :
721 8453 : if (name == "status" && !hack_status)
722 : {
723 0 : std::stringstream message;
724 0 : message << "A socket can't be named 'status'.";
725 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
726 0 : }
727 :
728 15801 : for (auto& s : sockets)
729 7348 : if (s->get_name() == name)
730 : {
731 0 : std::stringstream message;
732 : message << "Impossible to create this socket because an other socket has the same name ('socket.name' = "
733 0 : << name << ", 'task.name' = " << this->get_name() << ", 'module.name' = " << module->get_name()
734 0 : << ").";
735 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
736 0 : }
737 :
738 15801 : for (auto s : this->sockets)
739 7348 : if (s->get_name() == "status")
740 : {
741 0 : std::stringstream message;
742 0 : message << "Creating new sockets after the 'status' socket is forbidden.";
743 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
744 0 : }
745 :
746 8453 : std::pair<size_t, size_t> databytes_per_dim = { n_rows, n_cols * sizeof(T) };
747 8453 : auto s = std::make_shared<Socket>(*this, name, typeid(T), databytes_per_dim, type, this->is_fast());
748 :
749 8453 : sockets.push_back(std::move(s));
750 :
751 8453 : this->sockets_dataptr_init.push_back(nullptr);
752 8453 : this->sockets_databytes_per_frame.push_back(sockets.back()->get_databytes() / this->get_module().get_n_frames());
753 8453 : this->sockets_data.push_back(
754 16906 : std::vector<int8_t>((this->get_module().get_n_frames_per_wave() > 1)
755 0 : ? this->sockets_databytes_per_frame.back() * this->get_module().get_n_frames_per_wave()
756 : : 0));
757 :
758 16906 : return *sockets.back();
759 8453 : }
760 :
761 : template<typename T>
762 : size_t
763 2062 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols)
764 : {
765 2062 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SIN);
766 2062 : last_input_socket = &s;
767 :
768 2062 : this->n_input_sockets++;
769 :
770 2062 : return sockets.size() - 1;
771 : }
772 :
773 : size_t
774 674 : Task::create_2d_socket_in(const std::string& name,
775 : const size_t n_rows,
776 : const size_t n_cols,
777 : const std::type_index& datatype)
778 : {
779 674 : if (datatype == typeid(int8_t))
780 66 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
781 608 : else if (datatype == typeid(uint8_t))
782 455 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
783 153 : else if (datatype == typeid(int16_t))
784 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
785 153 : else if (datatype == typeid(uint16_t))
786 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
787 153 : else if (datatype == typeid(int32_t))
788 30 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
789 123 : else if (datatype == typeid(uint32_t))
790 107 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
791 16 : else if (datatype == typeid(int64_t))
792 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
793 16 : else if (datatype == typeid(uint64_t))
794 16 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
795 0 : else if (datatype == typeid(float))
796 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
797 0 : else if (datatype == typeid(double))
798 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
799 : else
800 : {
801 0 : std::stringstream message;
802 0 : message << "This should never happen.";
803 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
804 0 : }
805 : }
806 :
807 : size_t
808 0 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
809 : {
810 0 : switch (datatype)
811 : {
812 0 : case datatype_t::F64:
813 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
814 : break;
815 0 : case datatype_t::F32:
816 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
817 : break;
818 0 : case datatype_t::S64:
819 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
820 : break;
821 0 : case datatype_t::S32:
822 0 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
823 : break;
824 0 : case datatype_t::S16:
825 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
826 : break;
827 0 : case datatype_t::S8:
828 0 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
829 : break;
830 0 : case datatype_t::U64:
831 0 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
832 : break;
833 0 : case datatype_t::U32:
834 0 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
835 : break;
836 0 : case datatype_t::U16:
837 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
838 : break;
839 0 : case datatype_t::U8:
840 0 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
841 : break;
842 0 : default:
843 : {
844 0 : std::stringstream message;
845 0 : message << "This should never happen.";
846 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
847 : break;
848 0 : }
849 : }
850 : }
851 :
852 : template<typename T>
853 : size_t
854 5339 : Task::create_2d_socket_out(const std::string& name, const size_t n_rows, const size_t n_cols, const bool hack_status)
855 : {
856 5339 : create_2d_socket<T>(name, n_rows, n_cols, socket_t::SOUT, hack_status);
857 5339 : this->n_output_sockets++;
858 :
859 5339 : return sockets.size() - 1;
860 : }
861 :
862 : size_t
863 669 : Task::create_2d_socket_out(const std::string& name,
864 : const size_t n_rows,
865 : const size_t n_cols,
866 : const std::type_index& datatype,
867 : const bool hack_status)
868 : {
869 669 : if (datatype == typeid(int8_t))
870 61 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
871 608 : else if (datatype == typeid(uint8_t))
872 455 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
873 153 : else if (datatype == typeid(int16_t))
874 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
875 153 : else if (datatype == typeid(uint16_t))
876 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
877 153 : else if (datatype == typeid(int32_t))
878 30 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
879 123 : else if (datatype == typeid(uint32_t))
880 107 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
881 16 : else if (datatype == typeid(int64_t))
882 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
883 16 : else if (datatype == typeid(uint64_t))
884 16 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
885 0 : else if (datatype == typeid(float))
886 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
887 0 : else if (datatype == typeid(double))
888 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
889 : else
890 : {
891 0 : std::stringstream message;
892 0 : message << "This should never happen.";
893 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
894 0 : }
895 : }
896 :
897 : size_t
898 0 : Task::create_2d_socket_out(const std::string& name,
899 : const size_t n_rows,
900 : const size_t n_cols,
901 : const datatype_t datatype,
902 : const bool hack_status)
903 : {
904 0 : switch (datatype)
905 : {
906 0 : case datatype_t::F64:
907 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
908 : break;
909 0 : case datatype_t::F32:
910 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
911 : break;
912 0 : case datatype_t::S64:
913 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
914 : break;
915 0 : case datatype_t::S32:
916 0 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
917 : break;
918 0 : case datatype_t::S16:
919 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
920 : break;
921 0 : case datatype_t::S8:
922 0 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
923 : break;
924 0 : case datatype_t::U64:
925 0 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
926 : break;
927 0 : case datatype_t::U32:
928 0 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
929 : break;
930 0 : case datatype_t::U16:
931 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
932 : break;
933 0 : case datatype_t::U8:
934 0 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
935 : break;
936 0 : default:
937 : {
938 0 : std::stringstream message;
939 0 : message << "This should never happen.";
940 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
941 : break;
942 0 : }
943 : }
944 : }
945 :
946 : template<typename T>
947 : size_t
948 1052 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols)
949 : {
950 1052 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SFWD);
951 1052 : last_input_socket = &s;
952 :
953 1052 : this->n_fwd_sockets++;
954 :
955 1052 : return sockets.size() - 1;
956 : }
957 :
958 : size_t
959 0 : Task::create_2d_socket_fwd(const std::string& name,
960 : const size_t n_rows,
961 : const size_t n_cols,
962 : const std::type_index& datatype)
963 : {
964 0 : if (datatype == typeid(int8_t))
965 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
966 0 : else if (datatype == typeid(uint8_t))
967 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
968 0 : else if (datatype == typeid(int16_t))
969 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
970 0 : else if (datatype == typeid(uint16_t))
971 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
972 0 : else if (datatype == typeid(int32_t))
973 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
974 0 : else if (datatype == typeid(uint32_t))
975 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
976 0 : else if (datatype == typeid(int64_t))
977 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
978 0 : else if (datatype == typeid(uint64_t))
979 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
980 0 : else if (datatype == typeid(float))
981 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
982 0 : else if (datatype == typeid(double))
983 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
984 : else
985 : {
986 0 : std::stringstream message;
987 0 : message << "This should never happen.";
988 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
989 0 : }
990 : }
991 :
992 : size_t
993 0 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
994 : {
995 0 : switch (datatype)
996 : {
997 0 : case datatype_t::F64:
998 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
999 : break;
1000 0 : case datatype_t::F32:
1001 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
1002 : break;
1003 0 : case datatype_t::S64:
1004 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
1005 : break;
1006 0 : case datatype_t::S32:
1007 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
1008 : break;
1009 0 : case datatype_t::S16:
1010 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
1011 : break;
1012 0 : case datatype_t::S8:
1013 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
1014 : break;
1015 0 : case datatype_t::U64:
1016 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
1017 : break;
1018 0 : case datatype_t::U32:
1019 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
1020 : break;
1021 0 : case datatype_t::U16:
1022 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
1023 : break;
1024 0 : case datatype_t::U8:
1025 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
1026 : break;
1027 0 : default:
1028 : {
1029 0 : std::stringstream message;
1030 0 : message << "This should never happen.";
1031 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1032 : break;
1033 0 : }
1034 : }
1035 : }
1036 :
1037 : void
1038 3398 : Task::create_codelet(std::function<int(module::Module& m, Task& t, const size_t frame_id)>& codelet)
1039 : {
1040 3398 : this->codelet = codelet;
1041 :
1042 : // create automatically a socket that contains the status of the task
1043 3398 : const bool hack_status = true;
1044 3398 : auto s = this->template create_2d_socket_out<int>("status", 1, this->get_module().get_n_waves(), hack_status);
1045 3398 : this->sockets[s]->dataptr = (void*)this->status.data();
1046 3398 : }
1047 :
1048 : void
1049 82808 : Task::update_n_frames(const size_t old_n_frames, const size_t new_n_frames)
1050 : {
1051 296908 : for (auto& s : this->sockets)
1052 : {
1053 214100 : if (s->get_name() == "status")
1054 : {
1055 82808 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1056 : {
1057 73282 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1058 73282 : this->status.resize(this->get_module().get_n_waves());
1059 73282 : s->set_dataptr((void*)this->status.data());
1060 : }
1061 : }
1062 : else
1063 : {
1064 131292 : const auto old_databytes = s->get_databytes();
1065 131292 : const auto new_databytes = (old_databytes / old_n_frames) * new_n_frames;
1066 131292 : s->set_databytes(new_databytes);
1067 :
1068 131292 : const size_t prev_n_rows_wo_nfra = s->get_n_rows() / old_n_frames;
1069 131292 : s->set_n_rows(prev_n_rows_wo_nfra * new_n_frames);
1070 :
1071 131292 : if (s->get_type() == socket_t::SOUT)
1072 : {
1073 51842 : s->resize_out_buffer(new_databytes);
1074 : }
1075 : }
1076 : }
1077 82808 : }
1078 :
1079 : void
1080 19052 : Task::update_n_frames_per_wave(const size_t /*old_n_frames_per_wave*/, const size_t new_n_frames_per_wave)
1081 : {
1082 19052 : size_t s_id = 0;
1083 77680 : for (auto& s : this->sockets)
1084 : {
1085 58628 : if (s->get_name() == "status")
1086 : {
1087 19052 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1088 : {
1089 9526 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1090 9526 : this->status.resize(this->get_module().get_n_waves());
1091 9526 : s->set_dataptr((void*)this->status.data());
1092 : }
1093 : }
1094 : else
1095 : {
1096 59364 : this->sockets_data[s_id].resize(
1097 19788 : (new_n_frames_per_wave > 1) ? this->sockets_databytes_per_frame[s_id] * new_n_frames_per_wave : 0);
1098 : }
1099 58628 : s_id++;
1100 : }
1101 19052 : }
1102 :
1103 : void
1104 44609 : Task::allocate_outbuffers()
1105 : {
1106 44609 : if (!this->is_outbuffers_allocated())
1107 : {
1108 : std::function<void(Socket * socket, void* data_ptr)> spread_dataptr =
1109 49496 : [&spread_dataptr](Socket* socket, void* data_ptr)
1110 : {
1111 87888 : for (auto bound_socket : socket->get_bound_sockets())
1112 : {
1113 44484 : if (bound_socket->get_type() == socket_t::SIN)
1114 : {
1115 38392 : bound_socket->set_dataptr(data_ptr);
1116 : }
1117 6092 : else if (bound_socket->get_type() == socket_t::SFWD)
1118 : {
1119 6092 : bound_socket->set_dataptr(data_ptr);
1120 6092 : spread_dataptr(bound_socket, data_ptr);
1121 : }
1122 : else
1123 : {
1124 0 : std::stringstream message;
1125 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1126 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1127 0 : }
1128 : }
1129 87085 : };
1130 168429 : for (auto s : this->sockets)
1131 : {
1132 124748 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1133 : {
1134 37360 : if (s->get_dataptr() == nullptr)
1135 : {
1136 37312 : s->allocate_buffer();
1137 37312 : spread_dataptr(s.get(), s->get_dataptr());
1138 : }
1139 : }
1140 124748 : }
1141 43681 : this->set_outbuffers_allocated(true);
1142 43681 : }
1143 44609 : }
1144 : void
1145 38519 : Task::deallocate_outbuffers()
1146 : {
1147 38519 : if (!this->is_outbuffers_allocated())
1148 : {
1149 0 : std::stringstream message;
1150 : message << "Task out sockets buffers are not allocated"
1151 0 : << ", task name : " << this->get_name();
1152 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1153 0 : }
1154 41153 : std::function<void(Socket * socket)> spread_nullptr = [&spread_nullptr](Socket* socket)
1155 : {
1156 70819 : for (auto bound_socket : socket->get_bound_sockets())
1157 : {
1158 33930 : if (bound_socket->get_type() == socket_t::SIN)
1159 : {
1160 29619 : bound_socket->set_dataptr(nullptr);
1161 : }
1162 4311 : else if (bound_socket->get_type() == socket_t::SFWD)
1163 : {
1164 4264 : bound_socket->set_dataptr(nullptr);
1165 4264 : spread_nullptr(bound_socket);
1166 : }
1167 47 : else if (dynamic_cast<const module::Set*>(&bound_socket->get_task().get_module()))
1168 : {
1169 : // hack: for set that bind SOUT to SOUT for perf
1170 47 : bound_socket->set_dataptr(nullptr);
1171 : }
1172 : else
1173 : {
1174 0 : std::stringstream message;
1175 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1176 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1177 0 : }
1178 : }
1179 75408 : };
1180 147608 : for (auto s : this->sockets)
1181 : {
1182 109089 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1183 : {
1184 32625 : if (s->get_dataptr() != nullptr)
1185 : {
1186 32625 : s->deallocate_buffer();
1187 32625 : spread_nullptr(s.get());
1188 : }
1189 : }
1190 109089 : }
1191 38519 : this->set_outbuffers_allocated(false);
1192 38519 : }
1193 :
1194 : bool
1195 382 : Task::can_exec() const
1196 : {
1197 1464 : for (size_t i = 0; i < sockets.size(); i++)
1198 1057 : if (sockets[i]->dataptr == nullptr) return false;
1199 364 : return true;
1200 : }
1201 :
1202 : std::chrono::nanoseconds
1203 11386 : Task::get_duration_total() const
1204 : {
1205 11386 : return this->duration_total;
1206 : }
1207 :
1208 : std::chrono::nanoseconds
1209 1360 : Task::get_duration_avg() const
1210 : {
1211 1360 : return this->duration_total / this->n_calls;
1212 : }
1213 :
1214 : std::chrono::nanoseconds
1215 1532 : Task::get_duration_min() const
1216 : {
1217 1532 : return this->duration_min;
1218 : }
1219 :
1220 : std::chrono::nanoseconds
1221 1532 : Task::get_duration_max() const
1222 : {
1223 1532 : return this->duration_max;
1224 : }
1225 :
1226 : const std::vector<std::string>&
1227 385 : Task::get_timers_name() const
1228 : {
1229 385 : return this->timers_name;
1230 : }
1231 :
1232 : const std::vector<uint32_t>&
1233 59 : Task::get_timers_n_calls() const
1234 : {
1235 59 : return this->timers_n_calls;
1236 : }
1237 :
1238 : const std::vector<std::chrono::nanoseconds>&
1239 59 : Task::get_timers_total() const
1240 : {
1241 59 : return this->timers_total;
1242 : }
1243 :
1244 : const std::vector<std::chrono::nanoseconds>&
1245 59 : Task::get_timers_min() const
1246 : {
1247 59 : return this->timers_min;
1248 : }
1249 :
1250 : const std::vector<std::chrono::nanoseconds>&
1251 59 : Task::get_timers_max() const
1252 : {
1253 59 : return this->timers_max;
1254 : }
1255 :
1256 : size_t
1257 2634 : Task::get_n_input_sockets() const
1258 : {
1259 2634 : return this->n_input_sockets;
1260 : }
1261 :
1262 : size_t
1263 0 : Task::get_n_output_sockets() const
1264 : {
1265 0 : return this->n_output_sockets;
1266 : }
1267 :
1268 : size_t
1269 3783 : Task::get_n_fwd_sockets() const
1270 : {
1271 3783 : return this->n_fwd_sockets;
1272 : }
1273 :
1274 : void
1275 0 : Task::register_timer(const std::string& name)
1276 : {
1277 0 : this->timers_name.push_back(name);
1278 0 : this->timers_n_calls.push_back(0);
1279 0 : this->timers_total.push_back(std::chrono::nanoseconds(0));
1280 0 : this->timers_max.push_back(std::chrono::nanoseconds(0));
1281 0 : this->timers_min.push_back(std::chrono::nanoseconds(0));
1282 0 : }
1283 :
1284 : void
1285 80207 : Task::reset()
1286 : {
1287 80207 : this->n_calls = 0;
1288 80207 : this->duration_total = std::chrono::nanoseconds(0);
1289 80207 : this->duration_min = std::chrono::nanoseconds(0);
1290 80207 : this->duration_max = std::chrono::nanoseconds(0);
1291 :
1292 80207 : for (auto& x : this->timers_n_calls)
1293 0 : x = 0;
1294 80207 : for (auto& x : this->timers_total)
1295 0 : x = std::chrono::nanoseconds(0);
1296 80207 : for (auto& x : this->timers_min)
1297 0 : x = std::chrono::nanoseconds(0);
1298 80207 : for (auto& x : this->timers_max)
1299 0 : x = std::chrono::nanoseconds(0);
1300 80207 : }
1301 :
1302 : Task*
1303 73832 : Task::clone() const
1304 : {
1305 73832 : Task* t = new Task(*this);
1306 73832 : t->sockets.clear();
1307 73832 : t->last_input_socket = nullptr;
1308 73832 : t->fake_input_sockets.clear();
1309 73832 : t->set_outbuffers_allocated(false);
1310 :
1311 266751 : for (auto s : this->sockets)
1312 : {
1313 192919 : void* dataptr = nullptr;
1314 192919 : if (s->get_type() == socket_t::SOUT)
1315 : {
1316 121087 : if (s->get_name() == "status")
1317 : {
1318 73832 : dataptr = (void*)t->status.data();
1319 : }
1320 : }
1321 71832 : else if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1322 71832 : dataptr = s->_get_dataptr();
1323 :
1324 : // No need to allocate memory when cloning
1325 192919 : const std::pair<size_t, size_t> databytes_per_dim = { s->get_n_rows(), s->get_databytes() / s->get_n_rows() };
1326 : auto s_new = std::shared_ptr<Socket>(
1327 192919 : new Socket(*t, s->get_name(), s->get_datatype(), databytes_per_dim, s->get_type(), s->is_fast(), dataptr));
1328 192919 : t->sockets.push_back(s_new);
1329 :
1330 192919 : if (s_new->get_type() == socket_t::SIN || s_new->get_type() == socket_t::SFWD)
1331 71832 : t->last_input_socket = s_new.get();
1332 192919 : }
1333 :
1334 73832 : return t;
1335 : }
1336 :
1337 : void
1338 9240 : Task::_bind(Socket& s_out, const int priority)
1339 : {
1340 : // check if the 's_out' socket is already used for an other fake input socket
1341 9240 : bool already_bound = false;
1342 9240 : for (auto& fsi : this->fake_input_sockets)
1343 0 : if (&fsi->get_bound_socket() == &s_out)
1344 : {
1345 0 : already_bound = true;
1346 0 : break;
1347 : }
1348 :
1349 : // check if the 's_out' socket is already used for an other read input/fwd socket
1350 9240 : if (!already_bound)
1351 28130 : for (auto& s : this->sockets)
1352 18890 : if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1353 : {
1354 : try // because 's->get_bound_socket()' can throw if s->bound_socket == 'nullptr'
1355 : {
1356 1547 : if (&s->get_bound_socket() == &s_out)
1357 : {
1358 0 : already_bound = true;
1359 0 : break;
1360 : }
1361 : }
1362 692 : catch (...)
1363 : {
1364 692 : }
1365 : }
1366 :
1367 : // if the s_out socket is not already bound, then create a new fake input socket
1368 9240 : if (!already_bound)
1369 : {
1370 9240 : this->fake_input_sockets.push_back(
1371 18480 : std::shared_ptr<Socket>(new Socket(*this,
1372 18480 : "fake" + std::to_string(this->fake_input_sockets.size()),
1373 9240 : s_out.get_datatype(),
1374 9240 : s_out.get_databytes(),
1375 : socket_t::SIN,
1376 9240 : this->is_fast())));
1377 9240 : this->fake_input_sockets.back()->_bind(s_out, priority);
1378 9240 : this->last_input_socket = this->fake_input_sockets.back().get();
1379 9240 : this->n_input_sockets++;
1380 : }
1381 9240 : }
1382 :
1383 : void
1384 0 : Task::bind(Socket& s_out, const int priority)
1385 : {
1386 : #ifdef SPU_SHOW_DEPRECATED
1387 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1388 : << std::endl;
1389 : #ifdef SPU_STACKTRACE
1390 : #ifdef SPU_COLORS
1391 : bool enable_color = true;
1392 : #else
1393 : bool enable_color = false;
1394 : #endif
1395 : cpptrace::generate_trace().print(std::clog, enable_color);
1396 : #endif
1397 : #endif
1398 0 : this->_bind(s_out, priority);
1399 0 : }
1400 :
1401 : void
1402 0 : Task::_bind(Task& t_out, const int priority)
1403 : {
1404 0 : this->_bind(*t_out.sockets.back(), priority);
1405 0 : }
1406 :
1407 : void
1408 0 : Task::bind(Task& t_out, const int priority)
1409 : {
1410 : #ifdef SPU_SHOW_DEPRECATED
1411 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1412 : << std::endl;
1413 : #ifdef SPU_STACKTRACE
1414 : #ifdef SPU_COLORS
1415 : bool enable_color = true;
1416 : #else
1417 : bool enable_color = false;
1418 : #endif
1419 : cpptrace::generate_trace().print(std::clog, enable_color);
1420 : #endif
1421 : #endif
1422 0 : this->_bind(t_out, priority);
1423 0 : }
1424 :
1425 : void
1426 7377 : Task::operator=(Socket& s_out)
1427 : {
1428 : #ifndef SPU_FAST
1429 7377 : if (s_out.get_type() == socket_t::SOUT || s_out.get_type() == socket_t::SFWD)
1430 : #endif
1431 7377 : this->_bind(s_out);
1432 : #ifndef SPU_FAST
1433 : else
1434 : {
1435 0 : std::stringstream message;
1436 : message << "'s_out' should be and output socket ("
1437 0 : << "'s_out.datatype' = " << type_to_string[s_out.get_datatype()] << ", "
1438 0 : << "'s_out.name' = " << s_out.get_name() << ", "
1439 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1440 0 : << "'s_out.type' = " << (s_out.get_type() == socket_t::SIN ? "SIN" : "SOUT") << ", "
1441 0 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ").";
1442 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1443 0 : }
1444 : #endif
1445 7377 : }
1446 :
1447 : void
1448 300 : Task::operator=(Task& t_out)
1449 : {
1450 300 : (*this) = *t_out.sockets.back();
1451 300 : }
1452 :
1453 : size_t
1454 51453 : Task::unbind(Socket& s_out)
1455 : {
1456 51453 : if (this->fake_input_sockets.size())
1457 : {
1458 6277 : size_t i = 0;
1459 6317 : for (auto& fsi : this->fake_input_sockets)
1460 : {
1461 6277 : if (&fsi->get_bound_socket() == &s_out)
1462 : {
1463 6237 : const auto pos = fsi->unbind(s_out);
1464 6237 : if (this->last_input_socket == fsi.get()) this->last_input_socket = nullptr;
1465 6237 : this->fake_input_sockets.erase(this->fake_input_sockets.begin() + i);
1466 6237 : this->n_input_sockets--;
1467 6237 : if (this->fake_input_sockets.size() && this->last_input_socket == nullptr)
1468 0 : this->last_input_socket = this->fake_input_sockets.back().get();
1469 6237 : return pos;
1470 : }
1471 40 : i++;
1472 : }
1473 :
1474 40 : std::stringstream message;
1475 : message << "'s_out' is not bound the this task ("
1476 80 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1477 40 : << "'s_out.name' = " << s_out.get_name() << ", "
1478 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1479 40 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1480 40 : << "'task.name' = " << this->get_name() << ", "
1481 80 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1482 40 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1483 40 : }
1484 : else
1485 : {
1486 45176 : std::stringstream message;
1487 : message << "This task does not have fake input socket to unbind ("
1488 90352 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1489 45176 : << "'s_out.name' = " << s_out.get_name() << ", "
1490 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1491 45176 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1492 45176 : << "'task.name' = " << this->get_name() << ", "
1493 90352 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1494 45176 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1495 45176 : }
1496 : }
1497 :
1498 : size_t
1499 244 : Task::unbind(Task& t_out)
1500 : {
1501 244 : return this->unbind(*t_out.sockets.back());
1502 : }
1503 :
1504 : size_t
1505 2578 : Task::get_n_static_input_sockets() const
1506 : {
1507 2578 : size_t n = 0;
1508 9497 : for (auto& s : this->sockets)
1509 6919 : if (s->get_type() == socket_t::SIN && s->_get_dataptr() != nullptr && s->bound_socket == nullptr) n++;
1510 2578 : return n;
1511 : }
1512 :
1513 : bool
1514 0 : Task::is_stateless() const
1515 : {
1516 0 : return this->get_module().is_stateless();
1517 : }
1518 :
1519 : bool
1520 0 : Task::is_stateful() const
1521 : {
1522 0 : return this->get_module().is_stateful();
1523 : }
1524 :
1525 : bool
1526 2068 : Task::is_replicable() const
1527 : {
1528 2068 : return this->replicable;
1529 : }
1530 :
1531 : void
1532 68 : Task::set_replicability(const bool replicable)
1533 : {
1534 68 : if (replicable && !this->module->is_clonable())
1535 : {
1536 0 : std::stringstream message;
1537 : message << "The replicability of this task cannot be set to true because its corresponding module is not "
1538 0 : << "clonable (task.name = '" << this->get_name() << "', module.name = '"
1539 0 : << this->get_module().get_name() << "').";
1540 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1541 0 : }
1542 68 : this->replicable = replicable;
1543 : // this->replicable = replicable ? this->module->is_clonable() : false;
1544 68 : }
1545 :
1546 : void
1547 156032 : Task::set_outbuffers_allocated(const bool outbuffers_allocated)
1548 : {
1549 156032 : this->outbuffers_allocated = outbuffers_allocated;
1550 156032 : }
1551 :
1552 : // ==================================================================================== explicit template instantiation
1553 : template size_t
1554 : Task::create_2d_socket_in<int8_t>(const std::string&, const size_t, const size_t);
1555 : template size_t
1556 : Task::create_2d_socket_in<uint8_t>(const std::string&, const size_t, const size_t);
1557 : template size_t
1558 : Task::create_2d_socket_in<int16_t>(const std::string&, const size_t, const size_t);
1559 : template size_t
1560 : Task::create_2d_socket_in<uint16_t>(const std::string&, const size_t, const size_t);
1561 : template size_t
1562 : Task::create_2d_socket_in<int32_t>(const std::string&, const size_t, const size_t);
1563 : template size_t
1564 : Task::create_2d_socket_in<uint32_t>(const std::string&, const size_t, const size_t);
1565 : template size_t
1566 : Task::create_2d_socket_in<int64_t>(const std::string&, const size_t, const size_t);
1567 : template size_t
1568 : Task::create_2d_socket_in<uint64_t>(const std::string&, const size_t, const size_t);
1569 : template size_t
1570 : Task::create_2d_socket_in<float>(const std::string&, const size_t, const size_t);
1571 : template size_t
1572 : Task::create_2d_socket_in<double>(const std::string&, const size_t, const size_t);
1573 :
1574 : template size_t
1575 : Task::create_2d_socket_out<int8_t>(const std::string&, const size_t, const size_t, const bool);
1576 : template size_t
1577 : Task::create_2d_socket_out<uint8_t>(const std::string&, const size_t, const size_t, const bool);
1578 : template size_t
1579 : Task::create_2d_socket_out<int16_t>(const std::string&, const size_t, const size_t, const bool);
1580 : template size_t
1581 : Task::create_2d_socket_out<uint16_t>(const std::string&, const size_t, const size_t, const bool);
1582 : template size_t
1583 : Task::create_2d_socket_out<int32_t>(const std::string&, const size_t, const size_t, const bool);
1584 : template size_t
1585 : Task::create_2d_socket_out<uint32_t>(const std::string&, const size_t, const size_t, const bool);
1586 : template size_t
1587 : Task::create_2d_socket_out<int64_t>(const std::string&, const size_t, const size_t, const bool);
1588 : template size_t
1589 : Task::create_2d_socket_out<uint64_t>(const std::string&, const size_t, const size_t, const bool);
1590 : template size_t
1591 : Task::create_2d_socket_out<float>(const std::string&, const size_t, const size_t, const bool);
1592 : template size_t
1593 : Task::create_2d_socket_out<double>(const std::string&, const size_t, const size_t, const bool);
1594 :
1595 : template size_t
1596 : Task::create_2d_socket_fwd<int8_t>(const std::string&, const size_t, const size_t);
1597 : template size_t
1598 : Task::create_2d_socket_fwd<uint8_t>(const std::string&, const size_t, const size_t);
1599 : template size_t
1600 : Task::create_2d_socket_fwd<int16_t>(const std::string&, const size_t, const size_t);
1601 : template size_t
1602 : Task::create_2d_socket_fwd<uint16_t>(const std::string&, const size_t, const size_t);
1603 : template size_t
1604 : Task::create_2d_socket_fwd<int32_t>(const std::string&, const size_t, const size_t);
1605 : template size_t
1606 : Task::create_2d_socket_fwd<uint32_t>(const std::string&, const size_t, const size_t);
1607 : template size_t
1608 : Task::create_2d_socket_fwd<int64_t>(const std::string&, const size_t, const size_t);
1609 : template size_t
1610 : Task::create_2d_socket_fwd<uint64_t>(const std::string&, const size_t, const size_t);
1611 : template size_t
1612 : Task::create_2d_socket_fwd<float>(const std::string&, const size_t, const size_t);
1613 : template size_t
1614 : Task::create_2d_socket_fwd<double>(const std::string&, const size_t, const size_t);
1615 : // ==================================================================================== explicit template instantiation
|