Line data Source code
1 : #include <algorithm>
2 : #include <iomanip>
3 : #include <ios>
4 : #include <iostream>
5 : #include <rang.hpp>
6 : #include <sstream>
7 : #include <typeinfo>
8 :
9 : #include "Module/Module.hpp"
10 : #include "Module/Stateful/Set/Set.hpp"
11 : #include "Runtime/Socket/Socket.hpp"
12 : #include "Runtime/Task/Task.hpp"
13 : #include "Tools/Buffer_allocator/Buffer_allocator.hpp"
14 : #include "Tools/Exception/exception.hpp"
15 :
16 : using namespace spu;
17 : using namespace spu::runtime;
18 :
19 3805 : Task::Task(module::Module& module,
20 : const std::string& name,
21 : const bool stats,
22 : const bool fast,
23 : const bool debug,
24 3805 : const bool outbuffers_allocated)
25 3805 : : module(&module)
26 3805 : , name(name)
27 3805 : , stats(stats)
28 3805 : , fast(fast)
29 3805 : , debug(debug)
30 3805 : , outbuffers_allocated(outbuffers_allocated)
31 3805 : , debug_hex(false)
32 3805 : , replicable(module.is_clonable())
33 3805 : , debug_limit(-1)
34 3805 : , debug_precision(2)
35 3805 : , debug_frame_max(-1)
36 3805 : , codelet(
37 0 : [](module::Module& /*m*/, Task& /*t*/, const size_t /*frame_id*/) -> int
38 : {
39 0 : throw tools::unimplemented_error(__FILE__, __LINE__, __func__);
40 : return 0;
41 : })
42 3805 : , n_input_sockets(0)
43 3805 : , n_output_sockets(0)
44 3805 : , n_fwd_sockets(0)
45 3805 : , status(module.get_n_waves())
46 3805 : , n_calls(0)
47 3805 : , duration_total(std::chrono::nanoseconds(0))
48 3805 : , duration_min(std::chrono::nanoseconds(0))
49 3805 : , duration_max(std::chrono::nanoseconds(0))
50 7610 : , last_input_socket(nullptr)
51 : {
52 3805 : }
53 :
54 : Socket&
55 0 : Task::operator[](const std::string& sck_name)
56 : {
57 0 : std::string s_name = sck_name;
58 0 : s_name.erase(remove(s_name.begin(), s_name.end(), ' '), s_name.end());
59 :
60 0 : auto it = find_if(this->sockets.begin(),
61 : this->sockets.end(),
62 0 : [s_name](std::shared_ptr<runtime::Socket> s) { return s->get_name() == s_name; });
63 :
64 0 : if (it == this->sockets.end())
65 : {
66 0 : std::stringstream message;
67 0 : message << "runtime::Socket '" << s_name << "' not found for task '" << this->get_name() << "'.";
68 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
69 0 : }
70 :
71 0 : return *it->get();
72 0 : }
73 :
74 : void
75 90171 : Task::set_stats(const bool stats)
76 : {
77 90171 : this->stats = stats;
78 90171 : }
79 :
80 : void
81 96355 : Task::set_fast(const bool fast)
82 : {
83 96355 : this->fast = fast;
84 349369 : for (size_t i = 0; i < sockets.size(); i++)
85 253014 : sockets[i]->set_fast(this->fast);
86 96355 : }
87 :
88 : void
89 90079 : Task::set_debug(const bool debug)
90 : {
91 90079 : this->debug = debug;
92 90079 : }
93 :
94 : void
95 0 : Task::set_debug_hex(const bool debug_hex)
96 : {
97 0 : this->debug_hex = debug_hex;
98 0 : }
99 :
100 : void
101 90079 : Task::set_debug_limit(const uint32_t limit)
102 : {
103 90079 : this->debug_limit = (int32_t)limit;
104 90079 : }
105 :
106 : void
107 0 : Task::set_debug_precision(const uint8_t prec)
108 : {
109 0 : this->debug_precision = prec;
110 0 : }
111 :
112 : void
113 0 : Task::set_debug_frame_max(const uint32_t limit)
114 : {
115 0 : this->debug_frame_max = limit;
116 0 : }
117 :
118 : // trick to compile on the GNU compiler version 4 (where 'std::hexfloat' is unavailable)
119 : #if !defined(__clang__) && !defined(__llvm__) && defined(__GNUC__) && defined(__cplusplus) && __GNUC__ < 5
120 : namespace std
121 : {
122 : class Hexfloat
123 : {
124 : public:
125 : void message(std::ostream& os) const { os << " /!\\ 'std::hexfloat' is not supported by this compiler. /!\\ "; }
126 : };
127 : Hexfloat hexfloat;
128 : }
129 : std::ostream&
130 : operator<<(std::ostream& os, const std::Hexfloat& obj)
131 : {
132 : obj.message(os);
133 : return os;
134 : }
135 : #endif
136 :
137 : template<typename T>
138 : static inline void
139 397 : display_data(const T* data,
140 : const size_t fra_size,
141 : const size_t n_fra,
142 : const size_t n_fra_per_w,
143 : const size_t limit,
144 : const size_t max_frame,
145 : const uint8_t p,
146 : const uint8_t n_spaces,
147 : const bool hex)
148 : {
149 397 : constexpr bool is_float_type = std::is_same<float, T>::value || std::is_same<double, T>::value;
150 :
151 397 : std::ios::fmtflags f(std::cout.flags());
152 397 : if (hex)
153 : {
154 : if (is_float_type)
155 0 : std::cout << std::hexfloat << std::hex;
156 : else
157 0 : std::cout << std::hex;
158 : }
159 : else
160 397 : std::cout << std::fixed << std::setprecision(p) << std::dec;
161 :
162 397 : if (n_fra == 1 && max_frame != 0)
163 : {
164 5534 : for (size_t i = 0; i < limit; i++)
165 : {
166 5137 : if (hex)
167 0 : std::cout << (!is_float_type ? "0x" : "") << +data[i] << (i < limit - 1 ? ", " : "");
168 : else
169 5137 : std::cout << std::setw(p + 3) << +data[i] << (i < limit - 1 ? ", " : "");
170 : }
171 397 : std::cout << (limit < fra_size ? ", ..." : "");
172 397 : }
173 : else
174 : {
175 0 : std::string spaces = "#";
176 0 : for (uint8_t s = 0; s < n_spaces - 1; s++)
177 0 : spaces += " ";
178 :
179 0 : auto n_digits_dec = [](size_t f) -> size_t
180 : {
181 0 : size_t count = 0;
182 0 : while (f && ++count)
183 0 : f /= 10;
184 0 : return count;
185 : };
186 :
187 0 : const auto n_digits = n_digits_dec(max_frame);
188 0 : auto ftos = [&n_digits_dec, &n_digits](size_t f) -> std::string
189 : {
190 0 : auto n_zero = n_digits - n_digits_dec(f);
191 0 : std::string f_str = "";
192 0 : for (size_t z = 0; z < n_zero; z++)
193 0 : f_str += "0";
194 0 : f_str += std::to_string(f);
195 0 : return f_str;
196 0 : };
197 :
198 0 : const auto n_digits_w = n_digits_dec((max_frame / n_fra_per_w) == 0 ? 1 : (max_frame / n_fra_per_w));
199 0 : auto wtos = [&n_digits_dec, &n_digits_w](size_t w) -> std::string
200 : {
201 0 : auto n_zero = n_digits_w - n_digits_dec(w);
202 0 : std::string f_str = "";
203 0 : for (size_t z = 0; z < n_zero; z++)
204 0 : f_str += "0";
205 0 : f_str += std::to_string(w);
206 0 : return f_str;
207 0 : };
208 :
209 0 : for (size_t f = 0; f < max_frame; f++)
210 : {
211 0 : const auto w = f / n_fra_per_w;
212 0 : std::cout << (f >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f" << ftos(f + 1) << "_w"
213 0 : << wtos(w + 1) << rang::style::reset << "(";
214 :
215 0 : for (size_t i = 0; i < limit; i++)
216 : {
217 0 : if (hex)
218 0 : std::cout << (!is_float_type ? "0x" : "") << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
219 : else
220 0 : std::cout << std::setw(p + 3) << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
221 : }
222 0 : std::cout << (limit < fra_size ? ", ..." : "") << ")" << (f < n_fra - 1 ? ", \n" : "");
223 : }
224 :
225 0 : if (max_frame < n_fra)
226 : {
227 0 : const auto w1 = max_frame / n_fra_per_w;
228 0 : const auto w2 = n_fra / n_fra_per_w;
229 0 : std::cout << (max_frame >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f"
230 0 : << std::setw(n_digits) << max_frame + 1 << "_w" << std::setw(n_digits_w) << w1 + 1 << "->"
231 0 : << "f" << std::setw(n_digits) << n_fra << "_w" << std::setw(n_digits_w) << w2 + 1 << ":"
232 0 : << rang::style::reset << "(...)";
233 : }
234 0 : }
235 :
236 397 : std::cout.flags(f);
237 397 : }
238 :
239 : void
240 4178959 : Task::_exec(const int frame_id, const bool managed_memory)
241 : {
242 4178959 : const auto n_frames = this->get_module().get_n_frames();
243 4157547 : const auto n_frames_per_wave = this->get_module().get_n_frames_per_wave();
244 4115187 : const auto n_waves = this->get_module().get_n_waves();
245 4079608 : const auto n_frames_per_wave_rest = this->get_module().get_n_frames_per_wave_rest();
246 :
247 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using the no
248 : // copy mode
249 4054073 : int* status = this->sockets.back()->get_dataptr<int>();
250 10686193 : for (size_t w = 0; w < n_waves; w++)
251 6760424 : status[w] = (int)status_t::UNKNOWN;
252 :
253 3925769 : if ((managed_memory == false && frame_id >= 0) || (frame_id == -1 && n_frames_per_wave == n_frames) ||
254 257235 : (frame_id == 0 && n_frames_per_wave == 1) || (frame_id == 0 && n_waves > 1) ||
255 0 : (frame_id == 0 && n_frames_per_wave_rest == 0))
256 : {
257 3664087 : const auto real_frame_id = frame_id == -1 ? 0 : frame_id;
258 3664087 : const size_t w = (real_frame_id % n_frames) / n_frames_per_wave;
259 3664087 : status[w] = this->codelet(*this->module, *this, real_frame_id);
260 3985730 : }
261 : else
262 : {
263 : // save the initial dataptr of the sockets
264 700126 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
265 436752 : sockets_dataptr_init[sid] = (int8_t*)this->sockets[sid]->_get_dataptr();
266 :
267 257122 : if (frame_id > 0 && managed_memory == true && n_frames_per_wave > 1)
268 : {
269 0 : const size_t w = (frame_id % n_frames) / n_frames_per_wave;
270 0 : const size_t w_pos = frame_id % n_frames_per_wave;
271 :
272 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
273 : {
274 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
275 0 : std::copy(
276 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 0) * sockets_databytes_per_frame[sid],
277 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 1) * sockets_databytes_per_frame[sid],
278 0 : sockets_data[sid].begin() + w_pos * sockets_databytes_per_frame[sid]);
279 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
280 : }
281 :
282 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
283 :
284 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
285 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
286 0 : std::copy(sockets_data[sid].begin() + (w_pos + 0) * sockets_databytes_per_frame[sid],
287 0 : sockets_data[sid].begin() + (w_pos + 1) * sockets_databytes_per_frame[sid],
288 0 : sockets_dataptr_init[sid] + (frame_id % n_frames) * sockets_databytes_per_frame[sid]);
289 0 : }
290 : else // if (frame_id <= 0 || n_frames_per_wave == 1)
291 : {
292 257122 : const size_t w_start = (frame_id < 0) ? 0 : frame_id % n_waves;
293 257122 : const size_t w_stop = (frame_id < 0) ? n_waves : w_start + 1;
294 :
295 257122 : size_t w = 0;
296 257122 : auto exec_status = status_t::SUCCESS;
297 2716577 : for (w = w_start; w < w_stop - 1 && exec_status != status_t::FAILURE_STOP; w++)
298 : {
299 6851679 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
300 4190162 : this->sockets[sid]->dataptr =
301 8560542 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
302 :
303 2252652 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
304 2459455 : exec_status = (status_t)status[w];
305 : }
306 :
307 238147 : if (exec_status != status_t::FAILURE_STOP)
308 : {
309 254622 : if (n_frames_per_wave_rest == 0)
310 : {
311 690863 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
312 432858 : this->sockets[sid]->dataptr =
313 869072 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
314 :
315 250782 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
316 : }
317 : else
318 : {
319 1 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
320 : {
321 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
322 0 : std::copy(sockets_dataptr_init[sid] +
323 0 : w * n_frames_per_wave * sockets_databytes_per_frame[sid],
324 0 : sockets_dataptr_init[sid] + n_frames * sockets_databytes_per_frame[sid],
325 0 : sockets_data[sid].begin());
326 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
327 : }
328 :
329 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
330 :
331 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
332 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
333 0 : std::copy(
334 0 : sockets_data[sid].begin(),
335 0 : sockets_data[sid].begin() + n_frames_per_wave_rest * sockets_databytes_per_frame[sid],
336 0 : sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
337 : }
338 : }
339 : }
340 :
341 : // restore the initial dataptr of the sockets
342 686954 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
343 432390 : this->sockets[sid]->dataptr = (void*)sockets_dataptr_init[sid];
344 : }
345 4236844 : }
346 :
347 : const std::vector<int>&
348 4118753 : Task::exec(const int frame_id, const bool managed_memory)
349 : {
350 : #ifndef SPU_FAST
351 4118753 : if (this->is_fast() && !this->is_debug() && !this->is_stats())
352 : {
353 : #endif
354 4123065 : this->_exec(frame_id, managed_memory);
355 4219715 : this->n_calls++;
356 4219715 : return this->get_status();
357 : #ifndef SPU_FAST
358 : }
359 :
360 7347 : if (frame_id != -1 && (size_t)frame_id >= this->get_module().get_n_frames())
361 : {
362 0 : std::stringstream message;
363 0 : message << "'frame_id' has to be equal to '-1' or to be smaller than 'n_frames' ('frame_id' = " << frame_id
364 0 : << ", 'n_frames' = " << this->get_module().get_n_frames() << ").";
365 0 : throw tools::length_error(__FILE__, __LINE__, __func__, message.str());
366 0 : }
367 :
368 8853 : if (this->is_fast() || this->can_exec())
369 : {
370 8829 : size_t max_n_chars = 0;
371 8829 : if (this->is_debug())
372 : {
373 329 : auto n_fra = this->module->get_n_frames();
374 329 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
375 :
376 : std::string module_name =
377 329 : module->get_custom_name().empty() ? module->get_name() : module->get_custom_name();
378 :
379 329 : std::cout << "# ";
380 329 : std::cout << rang::style::bold << rang::fg::green << module_name << rang::style::reset
381 329 : << "::" << rang::style::bold << rang::fg::magenta << get_name() << rang::style::reset << "(";
382 728 : for (auto i = 0; i < (int)sockets.size() - 1; i++)
383 : {
384 399 : auto& s = *sockets[i];
385 399 : auto s_type = s.get_type();
386 399 : auto n_elmts = s.get_databytes() / (size_t)s.get_datatype_size();
387 399 : std::cout << rang::style::bold << rang::fg::blue << (s_type == socket_t::SIN ? "const " : "")
388 399 : << s.get_datatype_string() << rang::style::reset << " " << s.get_name() << "["
389 798 : << (n_fra > 1 ? std::to_string(n_fra) + "x" : "") << (n_elmts / n_fra) << "]"
390 399 : << (i < (int)sockets.size() - 2 ? ", " : "");
391 :
392 399 : max_n_chars = std::max(s.get_name().size(), max_n_chars);
393 : }
394 329 : std::cout << ")" << std::endl;
395 :
396 1057 : for (auto& s : sockets)
397 : {
398 728 : auto s_type = s->get_type();
399 728 : if (s_type == socket_t::SIN || s_type == socket_t::SFWD)
400 : {
401 215 : std::string spaces;
402 343 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
403 128 : spaces += " ";
404 :
405 215 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
406 215 : auto fra_size = n_elmts / n_fra;
407 215 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
408 215 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
409 215 : auto p = debug_precision;
410 215 : auto h = debug_hex;
411 215 : std::cout << "# {IN} " << s->get_name() << spaces << " = [";
412 215 : if (s->get_datatype() == typeid(int8_t))
413 0 : display_data(s->get_dataptr<const int8_t>(),
414 : fra_size,
415 : n_fra,
416 : n_fra_per_w,
417 : limit,
418 : max_frame,
419 : p,
420 0 : (uint8_t)max_n_chars + 12,
421 : h);
422 215 : else if (s->get_datatype() == typeid(uint8_t))
423 170 : display_data(s->get_dataptr<const uint8_t>(),
424 : fra_size,
425 : n_fra,
426 : n_fra_per_w,
427 : limit,
428 : max_frame,
429 : p,
430 170 : (uint8_t)max_n_chars + 12,
431 : h);
432 45 : else if (s->get_datatype() == typeid(int16_t))
433 0 : display_data(s->get_dataptr<const int16_t>(),
434 : fra_size,
435 : n_fra,
436 : n_fra_per_w,
437 : limit,
438 : max_frame,
439 : p,
440 0 : (uint8_t)max_n_chars + 12,
441 : h);
442 45 : else if (s->get_datatype() == typeid(uint16_t))
443 0 : display_data(s->get_dataptr<const uint16_t>(),
444 : fra_size,
445 : n_fra,
446 : n_fra_per_w,
447 : limit,
448 : max_frame,
449 : p,
450 0 : (uint8_t)max_n_chars + 12,
451 : h);
452 45 : else if (s->get_datatype() == typeid(int32_t))
453 0 : display_data(s->get_dataptr<const int32_t>(),
454 : fra_size,
455 : n_fra,
456 : n_fra_per_w,
457 : limit,
458 : max_frame,
459 : p,
460 0 : (uint8_t)max_n_chars + 12,
461 : h);
462 45 : else if (s->get_datatype() == typeid(uint32_t))
463 27 : display_data(s->get_dataptr<const uint32_t>(),
464 : fra_size,
465 : n_fra,
466 : n_fra_per_w,
467 : limit,
468 : max_frame,
469 : p,
470 27 : (uint8_t)max_n_chars + 12,
471 : h);
472 18 : else if (s->get_datatype() == typeid(int64_t))
473 0 : display_data(s->get_dataptr<const int64_t>(),
474 : fra_size,
475 : n_fra,
476 : n_fra_per_w,
477 : limit,
478 : max_frame,
479 : p,
480 0 : (uint8_t)max_n_chars + 12,
481 : h);
482 18 : else if (s->get_datatype() == typeid(uint64_t))
483 18 : display_data(s->get_dataptr<const uint64_t>(),
484 : fra_size,
485 : n_fra,
486 : n_fra_per_w,
487 : limit,
488 : max_frame,
489 : p,
490 18 : (uint8_t)max_n_chars + 12,
491 : h);
492 0 : else if (s->get_datatype() == typeid(float))
493 0 : display_data(s->get_dataptr<const float>(),
494 : fra_size,
495 : n_fra,
496 : n_fra_per_w,
497 : limit,
498 : max_frame,
499 : p,
500 0 : (uint8_t)max_n_chars + 12,
501 : h);
502 0 : else if (s->get_datatype() == typeid(double))
503 0 : display_data(s->get_dataptr<const double>(),
504 : fra_size,
505 : n_fra,
506 : n_fra_per_w,
507 : limit,
508 : max_frame,
509 : p,
510 0 : (uint8_t)max_n_chars + 12,
511 : h);
512 215 : std::cout << "]" << std::endl;
513 215 : }
514 : }
515 329 : }
516 :
517 8821 : if (this->is_stats())
518 : {
519 8084 : auto t_start = std::chrono::steady_clock::now();
520 8106 : this->_exec(frame_id, managed_memory);
521 7969 : auto duration = std::chrono::steady_clock::now() - t_start;
522 :
523 7982 : this->duration_total += duration;
524 7974 : if (n_calls)
525 : {
526 7503 : this->duration_min = std::min(this->duration_min, duration);
527 7509 : this->duration_max = std::max(this->duration_max, duration);
528 : }
529 : else
530 : {
531 471 : this->duration_min = duration;
532 471 : this->duration_max = duration;
533 : }
534 : }
535 : else
536 : {
537 739 : this->_exec(frame_id, managed_memory);
538 : }
539 8748 : this->n_calls++;
540 :
541 8748 : if (this->is_debug())
542 : {
543 326 : auto n_fra = this->module->get_n_frames();
544 326 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
545 1046 : for (auto& s : sockets)
546 : {
547 720 : auto s_type = s->get_type();
548 720 : if ((s_type == socket_t::SOUT) && s->get_name() != "status")
549 : {
550 182 : std::string spaces;
551 200 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
552 18 : spaces += " ";
553 :
554 182 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
555 182 : auto fra_size = n_elmts / n_fra;
556 182 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
557 182 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
558 182 : auto p = debug_precision;
559 182 : auto h = debug_hex;
560 182 : std::cout << "# {OUT} " << s->get_name() << spaces << " = [";
561 182 : if (s->get_datatype() == typeid(int8_t))
562 0 : display_data(s->get_dataptr<const int8_t>(),
563 : fra_size,
564 : n_fra,
565 : n_fra_per_w,
566 : limit,
567 : max_frame,
568 : p,
569 0 : (uint8_t)max_n_chars + 12,
570 : h);
571 182 : else if (s->get_datatype() == typeid(uint8_t))
572 146 : display_data(s->get_dataptr<const uint8_t>(),
573 : fra_size,
574 : n_fra,
575 : n_fra_per_w,
576 : limit,
577 : max_frame,
578 : p,
579 146 : (uint8_t)max_n_chars + 12,
580 : h);
581 36 : else if (s->get_datatype() == typeid(int16_t))
582 0 : display_data(s->get_dataptr<const int16_t>(),
583 : fra_size,
584 : n_fra,
585 : n_fra_per_w,
586 : limit,
587 : max_frame,
588 : p,
589 0 : (uint8_t)max_n_chars + 12,
590 : h);
591 36 : else if (s->get_datatype() == typeid(uint16_t))
592 0 : display_data(s->get_dataptr<const uint16_t>(),
593 : fra_size,
594 : n_fra,
595 : n_fra_per_w,
596 : limit,
597 : max_frame,
598 : p,
599 0 : (uint8_t)max_n_chars + 12,
600 : h);
601 36 : else if (s->get_datatype() == typeid(int32_t))
602 0 : display_data(s->get_dataptr<const int32_t>(),
603 : fra_size,
604 : n_fra,
605 : n_fra_per_w,
606 : limit,
607 : max_frame,
608 : p,
609 0 : (uint8_t)max_n_chars + 12,
610 : h);
611 36 : else if (s->get_datatype() == typeid(uint32_t))
612 18 : display_data(s->get_dataptr<const uint32_t>(),
613 : fra_size,
614 : n_fra,
615 : n_fra_per_w,
616 : limit,
617 : max_frame,
618 : p,
619 18 : (uint8_t)max_n_chars + 12,
620 : h);
621 18 : else if (s->get_datatype() == typeid(int64_t))
622 0 : display_data(s->get_dataptr<const int64_t>(),
623 : fra_size,
624 : n_fra,
625 : n_fra_per_w,
626 : limit,
627 : max_frame,
628 : p,
629 0 : (uint8_t)max_n_chars + 12,
630 : h);
631 18 : else if (s->get_datatype() == typeid(uint64_t))
632 18 : display_data(s->get_dataptr<const uint64_t>(),
633 : fra_size,
634 : n_fra,
635 : n_fra_per_w,
636 : limit,
637 : max_frame,
638 : p,
639 18 : (uint8_t)max_n_chars + 12,
640 : h);
641 0 : else if (s->get_datatype() == typeid(float))
642 0 : display_data(s->get_dataptr<const float>(),
643 : fra_size,
644 : n_fra,
645 : n_fra_per_w,
646 : limit,
647 : max_frame,
648 : p,
649 0 : (uint8_t)max_n_chars + 12,
650 : h);
651 0 : else if (s->get_datatype() == typeid(double))
652 0 : display_data(s->get_dataptr<const double>(),
653 : fra_size,
654 : n_fra,
655 : n_fra_per_w,
656 : limit,
657 : max_frame,
658 : p,
659 0 : (uint8_t)max_n_chars + 12,
660 : h);
661 182 : std::cout << "]" << std::endl;
662 182 : }
663 : }
664 326 : std::cout << "# Returned status: [";
665 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using
666 : // the no copy mode
667 326 : int* status = this->sockets.back()->get_dataptr<int>();
668 652 : for (size_t w = 0; w < this->get_module().get_n_waves(); w++)
669 : {
670 326 : if (status_t_to_string.count(status[w]))
671 326 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w] << " '"
672 326 : << status_t_to_string[status[w]] << "'";
673 : else
674 0 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w];
675 : }
676 326 : std::cout << "]" << std::endl;
677 326 : std::cout << "#" << std::noshowbase << std::endl;
678 : }
679 :
680 : // if (exec_status < 0)
681 : // {
682 : // std::stringstream message;
683 : // message << "'exec_status' can't be negative ('exec_status' = " << exec_status << ").";
684 : // throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
685 : // }
686 :
687 8736 : return this->get_status();
688 : }
689 : else
690 : {
691 0 : std::stringstream socs;
692 0 : socs << "'socket(s).name' = [";
693 0 : auto s = 0;
694 0 : for (size_t i = 0; i < sockets.size(); i++)
695 0 : if (sockets[i]->dataptr == nullptr) socs << (s != 0 ? ", " : "") << sockets[i]->name;
696 0 : socs << "]";
697 :
698 0 : std::stringstream message;
699 : message << "The task cannot be executed because some of the inputs/output sockets are not fed ('task.name' = "
700 0 : << this->get_name() << ", 'module.name' = " << module->get_name() << ", " << socs.str() << ").";
701 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
702 0 : }
703 : #endif /* !SPU_FAST */
704 : }
705 :
706 : template<typename T>
707 : Socket&
708 9466 : Task::create_2d_socket(const std::string& name,
709 : const size_t n_rows,
710 : const size_t n_cols,
711 : const socket_t type,
712 : const bool hack_status)
713 : {
714 9466 : if (name.empty())
715 : {
716 0 : std::stringstream message;
717 : message << "Impossible to create this socket because the name is empty ('task.name' = " << this->get_name()
718 0 : << ", 'module.name' = " << module->get_name() << ").";
719 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
720 0 : }
721 :
722 9466 : if (name == "status" && !hack_status)
723 : {
724 0 : std::stringstream message;
725 0 : message << "A socket can't be named 'status'.";
726 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
727 0 : }
728 :
729 17697 : for (auto& s : sockets)
730 8231 : if (s->get_name() == name)
731 : {
732 0 : std::stringstream message;
733 : message << "Impossible to create this socket because an other socket has the same name ('socket.name' = "
734 0 : << name << ", 'task.name' = " << this->get_name() << ", 'module.name' = " << module->get_name()
735 0 : << ").";
736 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
737 0 : }
738 :
739 17697 : for (auto s : this->sockets)
740 8231 : if (s->get_name() == "status")
741 : {
742 0 : std::stringstream message;
743 0 : message << "Creating new sockets after the 'status' socket is forbidden.";
744 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
745 0 : }
746 :
747 9466 : std::pair<size_t, size_t> databytes_per_dim = { n_rows, n_cols * sizeof(T) };
748 9466 : auto s = std::make_shared<Socket>(*this, name, typeid(T), databytes_per_dim, type, this->is_fast());
749 :
750 9466 : sockets.push_back(std::move(s));
751 :
752 9466 : this->sockets_dataptr_init.push_back(nullptr);
753 9466 : this->sockets_databytes_per_frame.push_back(sockets.back()->get_databytes() / this->get_module().get_n_frames());
754 9466 : this->sockets_data.push_back(
755 18932 : std::vector<int8_t>((this->get_module().get_n_frames_per_wave() > 1)
756 0 : ? this->sockets_databytes_per_frame.back() * this->get_module().get_n_frames_per_wave()
757 : : 0));
758 :
759 18932 : return *sockets.back();
760 9466 : }
761 :
762 : template<typename T>
763 : size_t
764 2308 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols)
765 : {
766 2308 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SIN);
767 2308 : last_input_socket = &s;
768 :
769 2308 : this->n_input_sockets++;
770 :
771 2308 : return sockets.size() - 1;
772 : }
773 :
774 : size_t
775 749 : Task::create_2d_socket_in(const std::string& name,
776 : const size_t n_rows,
777 : const size_t n_cols,
778 : const std::type_index& datatype)
779 : {
780 749 : if (datatype == typeid(int8_t))
781 74 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
782 675 : else if (datatype == typeid(uint8_t))
783 508 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
784 167 : else if (datatype == typeid(int16_t))
785 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
786 167 : else if (datatype == typeid(uint16_t))
787 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
788 167 : else if (datatype == typeid(int32_t))
789 34 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
790 133 : else if (datatype == typeid(uint32_t))
791 115 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
792 18 : else if (datatype == typeid(int64_t))
793 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
794 18 : else if (datatype == typeid(uint64_t))
795 18 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
796 0 : else if (datatype == typeid(float))
797 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
798 0 : else if (datatype == typeid(double))
799 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
800 : else
801 : {
802 0 : std::stringstream message;
803 0 : message << "This should never happen.";
804 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
805 0 : }
806 : }
807 :
808 : size_t
809 0 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
810 : {
811 0 : switch (datatype)
812 : {
813 0 : case datatype_t::F64:
814 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
815 : break;
816 0 : case datatype_t::F32:
817 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
818 : break;
819 0 : case datatype_t::S64:
820 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
821 : break;
822 0 : case datatype_t::S32:
823 0 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
824 : break;
825 0 : case datatype_t::S16:
826 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
827 : break;
828 0 : case datatype_t::S8:
829 0 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
830 : break;
831 0 : case datatype_t::U64:
832 0 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
833 : break;
834 0 : case datatype_t::U32:
835 0 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
836 : break;
837 0 : case datatype_t::U16:
838 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
839 : break;
840 0 : case datatype_t::U8:
841 0 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
842 : break;
843 0 : default:
844 : {
845 0 : std::stringstream message;
846 0 : message << "This should never happen.";
847 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
848 : break;
849 0 : }
850 : }
851 : }
852 :
853 : template<typename T>
854 : size_t
855 5975 : Task::create_2d_socket_out(const std::string& name, const size_t n_rows, const size_t n_cols, const bool hack_status)
856 : {
857 5975 : create_2d_socket<T>(name, n_rows, n_cols, socket_t::SOUT, hack_status);
858 5975 : this->n_output_sockets++;
859 :
860 5975 : return sockets.size() - 1;
861 : }
862 :
863 : size_t
864 742 : Task::create_2d_socket_out(const std::string& name,
865 : const size_t n_rows,
866 : const size_t n_cols,
867 : const std::type_index& datatype,
868 : const bool hack_status)
869 : {
870 742 : if (datatype == typeid(int8_t))
871 67 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
872 675 : else if (datatype == typeid(uint8_t))
873 508 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
874 167 : else if (datatype == typeid(int16_t))
875 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
876 167 : else if (datatype == typeid(uint16_t))
877 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
878 167 : else if (datatype == typeid(int32_t))
879 34 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
880 133 : else if (datatype == typeid(uint32_t))
881 115 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
882 18 : else if (datatype == typeid(int64_t))
883 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
884 18 : else if (datatype == typeid(uint64_t))
885 18 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
886 0 : else if (datatype == typeid(float))
887 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
888 0 : else if (datatype == typeid(double))
889 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
890 : else
891 : {
892 0 : std::stringstream message;
893 0 : message << "This should never happen.";
894 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
895 0 : }
896 : }
897 :
898 : size_t
899 0 : Task::create_2d_socket_out(const std::string& name,
900 : const size_t n_rows,
901 : const size_t n_cols,
902 : const datatype_t datatype,
903 : const bool hack_status)
904 : {
905 0 : switch (datatype)
906 : {
907 0 : case datatype_t::F64:
908 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
909 : break;
910 0 : case datatype_t::F32:
911 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
912 : break;
913 0 : case datatype_t::S64:
914 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
915 : break;
916 0 : case datatype_t::S32:
917 0 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
918 : break;
919 0 : case datatype_t::S16:
920 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
921 : break;
922 0 : case datatype_t::S8:
923 0 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
924 : break;
925 0 : case datatype_t::U64:
926 0 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
927 : break;
928 0 : case datatype_t::U32:
929 0 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
930 : break;
931 0 : case datatype_t::U16:
932 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
933 : break;
934 0 : case datatype_t::U8:
935 0 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
936 : break;
937 0 : default:
938 : {
939 0 : std::stringstream message;
940 0 : message << "This should never happen.";
941 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
942 : break;
943 0 : }
944 : }
945 : }
946 :
947 : template<typename T>
948 : size_t
949 1183 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols)
950 : {
951 1183 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SFWD);
952 1183 : last_input_socket = &s;
953 :
954 1183 : this->n_fwd_sockets++;
955 :
956 1183 : return sockets.size() - 1;
957 : }
958 :
959 : size_t
960 0 : Task::create_2d_socket_fwd(const std::string& name,
961 : const size_t n_rows,
962 : const size_t n_cols,
963 : const std::type_index& datatype)
964 : {
965 0 : if (datatype == typeid(int8_t))
966 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
967 0 : else if (datatype == typeid(uint8_t))
968 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
969 0 : else if (datatype == typeid(int16_t))
970 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
971 0 : else if (datatype == typeid(uint16_t))
972 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
973 0 : else if (datatype == typeid(int32_t))
974 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
975 0 : else if (datatype == typeid(uint32_t))
976 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
977 0 : else if (datatype == typeid(int64_t))
978 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
979 0 : else if (datatype == typeid(uint64_t))
980 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
981 0 : else if (datatype == typeid(float))
982 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
983 0 : else if (datatype == typeid(double))
984 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
985 : else
986 : {
987 0 : std::stringstream message;
988 0 : message << "This should never happen.";
989 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
990 0 : }
991 : }
992 :
993 : size_t
994 0 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
995 : {
996 0 : switch (datatype)
997 : {
998 0 : case datatype_t::F64:
999 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
1000 : break;
1001 0 : case datatype_t::F32:
1002 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
1003 : break;
1004 0 : case datatype_t::S64:
1005 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
1006 : break;
1007 0 : case datatype_t::S32:
1008 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
1009 : break;
1010 0 : case datatype_t::S16:
1011 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
1012 : break;
1013 0 : case datatype_t::S8:
1014 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
1015 : break;
1016 0 : case datatype_t::U64:
1017 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
1018 : break;
1019 0 : case datatype_t::U32:
1020 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
1021 : break;
1022 0 : case datatype_t::U16:
1023 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
1024 : break;
1025 0 : case datatype_t::U8:
1026 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
1027 : break;
1028 0 : default:
1029 : {
1030 0 : std::stringstream message;
1031 0 : message << "This should never happen.";
1032 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1033 : break;
1034 0 : }
1035 : }
1036 : }
1037 :
1038 : void
1039 3805 : Task::create_codelet(std::function<int(module::Module& m, Task& t, const size_t frame_id)>& codelet)
1040 : {
1041 3805 : this->codelet = codelet;
1042 :
1043 : // create automatically a socket that contains the status of the task
1044 3805 : const bool hack_status = true;
1045 3805 : auto s = this->template create_2d_socket_out<int>("status", 1, this->get_module().get_n_waves(), hack_status);
1046 3805 : this->sockets[s]->dataptr = (void*)this->status.data();
1047 :
1048 3805 : if (tools::Buffer_allocator::get_task_autoalloc()) this->allocate_outbuffers();
1049 3805 : }
1050 :
1051 : void
1052 98510 : Task::update_n_frames(const size_t old_n_frames, const size_t new_n_frames)
1053 : {
1054 353544 : for (auto& s : this->sockets)
1055 : {
1056 255034 : if (s->get_name() == "status")
1057 : {
1058 98510 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1059 : {
1060 86520 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1061 86520 : this->status.resize(this->get_module().get_n_waves());
1062 86520 : s->set_dataptr((void*)this->status.data());
1063 : }
1064 : }
1065 : else
1066 : {
1067 156524 : const auto old_databytes = s->get_databytes();
1068 156524 : const auto new_databytes = (old_databytes / old_n_frames) * new_n_frames;
1069 156524 : s->set_databytes(new_databytes);
1070 :
1071 156524 : const size_t prev_n_rows_wo_nfra = s->get_n_rows() / old_n_frames;
1072 156524 : s->set_n_rows(prev_n_rows_wo_nfra * new_n_frames);
1073 :
1074 156524 : if (s->get_type() == socket_t::SOUT)
1075 : {
1076 62014 : s->resize_out_buffer(new_databytes);
1077 : }
1078 : }
1079 : }
1080 98510 : }
1081 :
1082 : void
1083 23980 : Task::update_n_frames_per_wave(const size_t /*old_n_frames_per_wave*/, const size_t new_n_frames_per_wave)
1084 : {
1085 23980 : size_t s_id = 0;
1086 96944 : for (auto& s : this->sockets)
1087 : {
1088 72964 : if (s->get_name() == "status")
1089 : {
1090 23980 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1091 : {
1092 11990 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1093 11990 : this->status.resize(this->get_module().get_n_waves());
1094 11990 : s->set_dataptr((void*)this->status.data());
1095 : }
1096 : }
1097 : else
1098 : {
1099 73476 : this->sockets_data[s_id].resize(
1100 24492 : (new_n_frames_per_wave > 1) ? this->sockets_databytes_per_frame[s_id] * new_n_frames_per_wave : 0);
1101 : }
1102 72964 : s_id++;
1103 : }
1104 23980 : }
1105 :
1106 : void
1107 59743 : Task::allocate_outbuffers()
1108 : {
1109 59743 : if (!this->is_outbuffers_allocated())
1110 : {
1111 : std::function<void(Socket * socket, void* data_ptr)> spread_dataptr =
1112 55671 : [&spread_dataptr](Socket* socket, void* data_ptr)
1113 : {
1114 94055 : for (auto bound_socket : socket->get_bound_sockets())
1115 : {
1116 44468 : if (bound_socket->get_type() == socket_t::SIN)
1117 : {
1118 38384 : bound_socket->set_dataptr(data_ptr);
1119 : }
1120 6084 : else if (bound_socket->get_type() == socket_t::SFWD)
1121 : {
1122 6084 : bound_socket->set_dataptr(data_ptr);
1123 6084 : spread_dataptr(bound_socket, data_ptr);
1124 : }
1125 : else
1126 : {
1127 0 : std::stringstream message;
1128 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1129 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1130 0 : }
1131 : }
1132 103007 : };
1133 203456 : for (auto s : this->sockets)
1134 : {
1135 150036 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1136 : {
1137 43552 : if (s->get_dataptr() == nullptr)
1138 : {
1139 43503 : s->allocate_buffer();
1140 43503 : spread_dataptr(s.get(), s->get_dataptr());
1141 : }
1142 : }
1143 150036 : }
1144 53420 : this->set_outbuffers_allocated(true);
1145 53420 : }
1146 59743 : }
1147 : void
1148 14761 : Task::deallocate_outbuffers()
1149 : {
1150 14761 : if (!this->is_outbuffers_allocated())
1151 : {
1152 0 : std::stringstream message;
1153 : message << "Task out sockets buffers are not allocated"
1154 0 : << ", task name : " << this->get_name();
1155 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1156 0 : }
1157 13796 : std::function<void(Socket * socket)> spread_nullptr = [&spread_nullptr](Socket* socket)
1158 : {
1159 22335 : for (auto bound_socket : socket->get_bound_sockets())
1160 : {
1161 9427 : if (bound_socket->get_type() == socket_t::SIN)
1162 : {
1163 8539 : bound_socket->set_dataptr(nullptr);
1164 : }
1165 888 : else if (bound_socket->get_type() == socket_t::SFWD)
1166 : {
1167 888 : bound_socket->set_dataptr(nullptr);
1168 888 : spread_nullptr(bound_socket);
1169 : }
1170 0 : else if (dynamic_cast<const module::Set*>(&bound_socket->get_task().get_module()))
1171 : {
1172 : // hack: for set that bind SOUT to SOUT for perf
1173 0 : bound_socket->set_dataptr(nullptr);
1174 : }
1175 : else
1176 : {
1177 0 : std::stringstream message;
1178 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1179 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1180 0 : }
1181 : }
1182 27669 : };
1183 55564 : for (auto s : this->sockets)
1184 : {
1185 40803 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1186 : {
1187 12020 : if (s->get_dataptr() != nullptr)
1188 : {
1189 12020 : s->deallocate_buffer();
1190 12020 : spread_nullptr(s.get());
1191 : }
1192 : }
1193 40803 : }
1194 14761 : this->set_outbuffers_allocated(false);
1195 14761 : }
1196 :
1197 : bool
1198 762 : Task::can_exec() const
1199 : {
1200 2959 : for (size_t i = 0; i < sockets.size(); i++)
1201 2175 : if (sockets[i]->dataptr == nullptr) return false;
1202 736 : return true;
1203 : }
1204 :
1205 : std::chrono::nanoseconds
1206 11676 : Task::get_duration_total() const
1207 : {
1208 11676 : return this->duration_total;
1209 : }
1210 :
1211 : std::chrono::nanoseconds
1212 1302 : Task::get_duration_avg() const
1213 : {
1214 1302 : return this->duration_total / this->n_calls;
1215 : }
1216 :
1217 : std::chrono::nanoseconds
1218 1532 : Task::get_duration_min() const
1219 : {
1220 1532 : return this->duration_min;
1221 : }
1222 :
1223 : std::chrono::nanoseconds
1224 1532 : Task::get_duration_max() const
1225 : {
1226 1532 : return this->duration_max;
1227 : }
1228 :
1229 : const std::vector<std::string>&
1230 385 : Task::get_timers_name() const
1231 : {
1232 385 : return this->timers_name;
1233 : }
1234 :
1235 : const std::vector<uint32_t>&
1236 59 : Task::get_timers_n_calls() const
1237 : {
1238 59 : return this->timers_n_calls;
1239 : }
1240 :
1241 : const std::vector<std::chrono::nanoseconds>&
1242 59 : Task::get_timers_total() const
1243 : {
1244 59 : return this->timers_total;
1245 : }
1246 :
1247 : const std::vector<std::chrono::nanoseconds>&
1248 59 : Task::get_timers_min() const
1249 : {
1250 59 : return this->timers_min;
1251 : }
1252 :
1253 : const std::vector<std::chrono::nanoseconds>&
1254 59 : Task::get_timers_max() const
1255 : {
1256 59 : return this->timers_max;
1257 : }
1258 :
1259 : size_t
1260 3189 : Task::get_n_input_sockets() const
1261 : {
1262 3189 : return this->n_input_sockets;
1263 : }
1264 :
1265 : size_t
1266 0 : Task::get_n_output_sockets() const
1267 : {
1268 0 : return this->n_output_sockets;
1269 : }
1270 :
1271 : size_t
1272 4330 : Task::get_n_fwd_sockets() const
1273 : {
1274 4330 : return this->n_fwd_sockets;
1275 : }
1276 :
1277 : void
1278 0 : Task::register_timer(const std::string& name)
1279 : {
1280 0 : this->timers_name.push_back(name);
1281 0 : this->timers_n_calls.push_back(0);
1282 0 : this->timers_total.push_back(std::chrono::nanoseconds(0));
1283 0 : this->timers_max.push_back(std::chrono::nanoseconds(0));
1284 0 : this->timers_min.push_back(std::chrono::nanoseconds(0));
1285 0 : }
1286 :
1287 : void
1288 90171 : Task::reset()
1289 : {
1290 90171 : this->n_calls = 0;
1291 90171 : this->duration_total = std::chrono::nanoseconds(0);
1292 90171 : this->duration_min = std::chrono::nanoseconds(0);
1293 90171 : this->duration_max = std::chrono::nanoseconds(0);
1294 :
1295 90171 : for (auto& x : this->timers_n_calls)
1296 0 : x = 0;
1297 90171 : for (auto& x : this->timers_total)
1298 0 : x = std::chrono::nanoseconds(0);
1299 90171 : for (auto& x : this->timers_min)
1300 0 : x = std::chrono::nanoseconds(0);
1301 90171 : for (auto& x : this->timers_max)
1302 0 : x = std::chrono::nanoseconds(0);
1303 90171 : }
1304 :
1305 : Task*
1306 83152 : Task::clone() const
1307 : {
1308 83152 : Task* t = new Task(*this);
1309 83152 : t->sockets.clear();
1310 83152 : t->last_input_socket = nullptr;
1311 83152 : t->fake_input_sockets.clear();
1312 83152 : t->set_outbuffers_allocated(false);
1313 :
1314 300314 : for (auto s : this->sockets)
1315 : {
1316 217162 : void* dataptr = nullptr;
1317 217162 : if (s->get_type() == socket_t::SOUT)
1318 : {
1319 136362 : if (s->get_name() == "status")
1320 : {
1321 83152 : dataptr = (void*)t->status.data();
1322 : }
1323 : }
1324 80800 : else if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1325 80800 : dataptr = s->_get_dataptr();
1326 :
1327 : // No need to allocate memory when cloning
1328 217162 : const std::pair<size_t, size_t> databytes_per_dim = { s->get_n_rows(), s->get_databytes() / s->get_n_rows() };
1329 : auto s_new = std::shared_ptr<Socket>(
1330 217162 : new Socket(*t, s->get_name(), s->get_datatype(), databytes_per_dim, s->get_type(), s->is_fast(), dataptr));
1331 217162 : t->sockets.push_back(s_new);
1332 :
1333 217162 : if (s_new->get_type() == socket_t::SIN || s_new->get_type() == socket_t::SFWD)
1334 80800 : t->last_input_socket = s_new.get();
1335 217162 : }
1336 :
1337 83152 : if (tools::Buffer_allocator::get_task_autoalloc()) t->allocate_outbuffers();
1338 :
1339 83152 : return t;
1340 : }
1341 :
1342 : void
1343 10446 : Task::_bind(Socket& s_out, const int priority)
1344 : {
1345 : // check if the 's_out' socket is already used for an other fake input socket
1346 10446 : bool already_bound = false;
1347 10446 : for (auto& fsi : this->fake_input_sockets)
1348 0 : if (&fsi->get_bound_socket() == &s_out)
1349 : {
1350 0 : already_bound = true;
1351 0 : break;
1352 : }
1353 :
1354 : // check if the 's_out' socket is already used for an other read input/fwd socket
1355 10446 : if (!already_bound)
1356 31855 : for (auto& s : this->sockets)
1357 21409 : if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1358 : {
1359 : try // because 's->get_bound_socket()' can throw if s->bound_socket == 'nullptr'
1360 : {
1361 1899 : if (&s->get_bound_socket() == &s_out)
1362 : {
1363 0 : already_bound = true;
1364 0 : break;
1365 : }
1366 : }
1367 863 : catch (...)
1368 : {
1369 863 : }
1370 : }
1371 :
1372 : // if the s_out socket is not already bound, then create a new fake input socket
1373 10446 : if (!already_bound)
1374 : {
1375 10446 : this->fake_input_sockets.push_back(
1376 20892 : std::shared_ptr<Socket>(new Socket(*this,
1377 20892 : "fake" + std::to_string(this->fake_input_sockets.size()),
1378 10446 : s_out.get_datatype(),
1379 10446 : s_out.get_databytes(),
1380 : socket_t::SIN,
1381 10446 : this->is_fast())));
1382 10446 : this->fake_input_sockets.back()->_bind(s_out, priority);
1383 10446 : this->last_input_socket = this->fake_input_sockets.back().get();
1384 10446 : this->n_input_sockets++;
1385 : }
1386 10446 : }
1387 :
1388 : void
1389 0 : Task::bind(Socket& s_out, const int priority)
1390 : {
1391 : #ifdef SPU_SHOW_DEPRECATED
1392 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1393 : << std::endl;
1394 : #ifdef SPU_STACKTRACE
1395 : #ifdef SPU_COLORS
1396 : bool enable_color = true;
1397 : #else
1398 : bool enable_color = false;
1399 : #endif
1400 : cpptrace::generate_trace().print(std::clog, enable_color);
1401 : #endif
1402 : #endif
1403 0 : this->_bind(s_out, priority);
1404 0 : }
1405 :
1406 : void
1407 0 : Task::_bind(Task& t_out, const int priority)
1408 : {
1409 0 : this->_bind(*t_out.sockets.back(), priority);
1410 0 : }
1411 :
1412 : void
1413 0 : Task::bind(Task& t_out, const int priority)
1414 : {
1415 : #ifdef SPU_SHOW_DEPRECATED
1416 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1417 : << std::endl;
1418 : #ifdef SPU_STACKTRACE
1419 : #ifdef SPU_COLORS
1420 : bool enable_color = true;
1421 : #else
1422 : bool enable_color = false;
1423 : #endif
1424 : cpptrace::generate_trace().print(std::clog, enable_color);
1425 : #endif
1426 : #endif
1427 0 : this->_bind(t_out, priority);
1428 0 : }
1429 :
1430 : void
1431 8283 : Task::operator=(Socket& s_out)
1432 : {
1433 : #ifndef SPU_FAST
1434 8283 : if (s_out.get_type() == socket_t::SOUT || s_out.get_type() == socket_t::SFWD)
1435 : #endif
1436 8283 : this->_bind(s_out);
1437 : #ifndef SPU_FAST
1438 : else
1439 : {
1440 0 : std::stringstream message;
1441 : message << "'s_out' should be and output socket ("
1442 0 : << "'s_out.datatype' = " << type_to_string[s_out.get_datatype()] << ", "
1443 0 : << "'s_out.name' = " << s_out.get_name() << ", "
1444 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1445 0 : << "'s_out.type' = " << (s_out.get_type() == socket_t::SIN ? "SIN" : "SOUT") << ", "
1446 0 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ").";
1447 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1448 0 : }
1449 : #endif
1450 8283 : }
1451 :
1452 : void
1453 342 : Task::operator=(Task& t_out)
1454 : {
1455 342 : (*this) = *t_out.sockets.back();
1456 342 : }
1457 :
1458 : size_t
1459 58086 : Task::unbind(Socket& s_out)
1460 : {
1461 58086 : if (this->fake_input_sockets.size())
1462 : {
1463 7248 : size_t i = 0;
1464 7314 : for (auto& fsi : this->fake_input_sockets)
1465 : {
1466 7248 : if (&fsi->get_bound_socket() == &s_out)
1467 : {
1468 7182 : const auto pos = fsi->unbind(s_out);
1469 7182 : if (this->last_input_socket == fsi.get()) this->last_input_socket = nullptr;
1470 7182 : this->fake_input_sockets.erase(this->fake_input_sockets.begin() + i);
1471 7182 : this->n_input_sockets--;
1472 7182 : if (this->fake_input_sockets.size() && this->last_input_socket == nullptr)
1473 0 : this->last_input_socket = this->fake_input_sockets.back().get();
1474 7182 : return pos;
1475 : }
1476 66 : i++;
1477 : }
1478 :
1479 66 : std::stringstream message;
1480 : message << "'s_out' is not bound the this task ("
1481 132 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1482 66 : << "'s_out.name' = " << s_out.get_name() << ", "
1483 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1484 66 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1485 66 : << "'task.name' = " << this->get_name() << ", "
1486 132 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1487 66 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1488 66 : }
1489 : else
1490 : {
1491 50838 : std::stringstream message;
1492 : message << "This task does not have fake input socket to unbind ("
1493 101676 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1494 50838 : << "'s_out.name' = " << s_out.get_name() << ", "
1495 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1496 50838 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1497 50838 : << "'task.name' = " << this->get_name() << ", "
1498 101676 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1499 50838 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1500 50838 : }
1501 : }
1502 :
1503 : size_t
1504 278 : Task::unbind(Task& t_out)
1505 : {
1506 278 : return this->unbind(*t_out.sockets.back());
1507 : }
1508 :
1509 : size_t
1510 3125 : Task::get_n_static_input_sockets() const
1511 : {
1512 3125 : size_t n = 0;
1513 11579 : for (auto& s : this->sockets)
1514 8454 : if (s->get_type() == socket_t::SIN && s->_get_dataptr() != nullptr && s->bound_socket == nullptr) n++;
1515 3125 : return n;
1516 : }
1517 :
1518 : bool
1519 0 : Task::is_stateless() const
1520 : {
1521 0 : return this->get_module().is_stateless();
1522 : }
1523 :
1524 : bool
1525 0 : Task::is_stateful() const
1526 : {
1527 0 : return this->get_module().is_stateful();
1528 : }
1529 :
1530 : bool
1531 2468 : Task::is_replicable() const
1532 : {
1533 2468 : return this->replicable;
1534 : }
1535 :
1536 : void
1537 75 : Task::set_replicability(const bool replicable)
1538 : {
1539 75 : if (replicable && !this->module->is_clonable())
1540 : {
1541 0 : std::stringstream message;
1542 : message << "The replicability of this task cannot be set to true because its corresponding module is not "
1543 0 : << "clonable (task.name = '" << this->get_name() << "', module.name = '"
1544 0 : << this->get_module().get_name() << "').";
1545 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1546 0 : }
1547 75 : this->replicable = replicable;
1548 : // this->replicable = replicable ? this->module->is_clonable() : false;
1549 75 : }
1550 :
1551 : void
1552 151333 : Task::set_outbuffers_allocated(const bool outbuffers_allocated)
1553 : {
1554 151333 : this->outbuffers_allocated = outbuffers_allocated;
1555 151333 : }
1556 :
1557 : // ==================================================================================== explicit template instantiation
1558 : template size_t
1559 : Task::create_2d_socket_in<int8_t>(const std::string&, const size_t, const size_t);
1560 : template size_t
1561 : Task::create_2d_socket_in<uint8_t>(const std::string&, const size_t, const size_t);
1562 : template size_t
1563 : Task::create_2d_socket_in<int16_t>(const std::string&, const size_t, const size_t);
1564 : template size_t
1565 : Task::create_2d_socket_in<uint16_t>(const std::string&, const size_t, const size_t);
1566 : template size_t
1567 : Task::create_2d_socket_in<int32_t>(const std::string&, const size_t, const size_t);
1568 : template size_t
1569 : Task::create_2d_socket_in<uint32_t>(const std::string&, const size_t, const size_t);
1570 : template size_t
1571 : Task::create_2d_socket_in<int64_t>(const std::string&, const size_t, const size_t);
1572 : template size_t
1573 : Task::create_2d_socket_in<uint64_t>(const std::string&, const size_t, const size_t);
1574 : template size_t
1575 : Task::create_2d_socket_in<float>(const std::string&, const size_t, const size_t);
1576 : template size_t
1577 : Task::create_2d_socket_in<double>(const std::string&, const size_t, const size_t);
1578 :
1579 : template size_t
1580 : Task::create_2d_socket_out<int8_t>(const std::string&, const size_t, const size_t, const bool);
1581 : template size_t
1582 : Task::create_2d_socket_out<uint8_t>(const std::string&, const size_t, const size_t, const bool);
1583 : template size_t
1584 : Task::create_2d_socket_out<int16_t>(const std::string&, const size_t, const size_t, const bool);
1585 : template size_t
1586 : Task::create_2d_socket_out<uint16_t>(const std::string&, const size_t, const size_t, const bool);
1587 : template size_t
1588 : Task::create_2d_socket_out<int32_t>(const std::string&, const size_t, const size_t, const bool);
1589 : template size_t
1590 : Task::create_2d_socket_out<uint32_t>(const std::string&, const size_t, const size_t, const bool);
1591 : template size_t
1592 : Task::create_2d_socket_out<int64_t>(const std::string&, const size_t, const size_t, const bool);
1593 : template size_t
1594 : Task::create_2d_socket_out<uint64_t>(const std::string&, const size_t, const size_t, const bool);
1595 : template size_t
1596 : Task::create_2d_socket_out<float>(const std::string&, const size_t, const size_t, const bool);
1597 : template size_t
1598 : Task::create_2d_socket_out<double>(const std::string&, const size_t, const size_t, const bool);
1599 :
1600 : template size_t
1601 : Task::create_2d_socket_fwd<int8_t>(const std::string&, const size_t, const size_t);
1602 : template size_t
1603 : Task::create_2d_socket_fwd<uint8_t>(const std::string&, const size_t, const size_t);
1604 : template size_t
1605 : Task::create_2d_socket_fwd<int16_t>(const std::string&, const size_t, const size_t);
1606 : template size_t
1607 : Task::create_2d_socket_fwd<uint16_t>(const std::string&, const size_t, const size_t);
1608 : template size_t
1609 : Task::create_2d_socket_fwd<int32_t>(const std::string&, const size_t, const size_t);
1610 : template size_t
1611 : Task::create_2d_socket_fwd<uint32_t>(const std::string&, const size_t, const size_t);
1612 : template size_t
1613 : Task::create_2d_socket_fwd<int64_t>(const std::string&, const size_t, const size_t);
1614 : template size_t
1615 : Task::create_2d_socket_fwd<uint64_t>(const std::string&, const size_t, const size_t);
1616 : template size_t
1617 : Task::create_2d_socket_fwd<float>(const std::string&, const size_t, const size_t);
1618 : template size_t
1619 : Task::create_2d_socket_fwd<double>(const std::string&, const size_t, const size_t);
1620 : // ==================================================================================== explicit template instantiation
|