Line data Source code
1 : #include <algorithm>
2 : #include <iomanip>
3 : #include <ios>
4 : #include <iostream>
5 : #include <rang.hpp>
6 : #include <sstream>
7 : #include <typeinfo>
8 :
9 : #include "Module/Module.hpp"
10 : #include "Module/Stateful/Set/Set.hpp"
11 : #include "Runtime/Socket/Socket.hpp"
12 : #include "Runtime/Task/Task.hpp"
13 : #include "Tools/Buffer_allocator/Buffer_allocator.hpp"
14 : #include "Tools/Exception/exception.hpp"
15 :
16 : using namespace spu;
17 : using namespace spu::runtime;
18 :
19 3805 : Task::Task(module::Module& module,
20 : const std::string& name,
21 : const bool stats,
22 : const bool fast,
23 : const bool debug,
24 3805 : const bool outbuffers_allocated)
25 3805 : : module(&module)
26 3805 : , name(name)
27 3805 : , stats(stats)
28 3805 : , fast(fast)
29 3805 : , debug(debug)
30 3805 : , outbuffers_allocated(outbuffers_allocated)
31 3805 : , debug_hex(false)
32 3805 : , replicable(module.is_clonable())
33 3805 : , debug_limit(-1)
34 3805 : , debug_precision(2)
35 3805 : , debug_frame_max(-1)
36 3805 : , codelet(
37 0 : [](module::Module& /*m*/, Task& /*t*/, const size_t /*frame_id*/) -> int
38 : {
39 0 : throw tools::unimplemented_error(__FILE__, __LINE__, __func__);
40 : return 0;
41 : })
42 3805 : , n_input_sockets(0)
43 3805 : , n_output_sockets(0)
44 3805 : , n_fwd_sockets(0)
45 3805 : , status(module.get_n_waves())
46 3805 : , n_calls(0)
47 3805 : , duration_total(std::chrono::nanoseconds(0))
48 3805 : , duration_min(std::chrono::nanoseconds(0))
49 3805 : , duration_max(std::chrono::nanoseconds(0))
50 7610 : , last_input_socket(nullptr)
51 : {
52 3805 : }
53 :
54 : Socket&
55 0 : Task::operator[](const std::string& sck_name)
56 : {
57 0 : std::string s_name = sck_name;
58 0 : s_name.erase(remove(s_name.begin(), s_name.end(), ' '), s_name.end());
59 :
60 0 : auto it = find_if(this->sockets.begin(),
61 : this->sockets.end(),
62 0 : [s_name](std::shared_ptr<runtime::Socket> s) { return s->get_name() == s_name; });
63 :
64 0 : if (it == this->sockets.end())
65 : {
66 0 : std::stringstream message;
67 0 : message << "runtime::Socket '" << s_name << "' not found for task '" << this->get_name() << "'.";
68 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
69 0 : }
70 :
71 0 : return *it->get();
72 0 : }
73 :
74 : void
75 90197 : Task::set_stats(const bool stats)
76 : {
77 90197 : this->stats = stats;
78 90197 : }
79 :
80 : void
81 96381 : Task::set_fast(const bool fast)
82 : {
83 96381 : this->fast = fast;
84 349437 : for (size_t i = 0; i < sockets.size(); i++)
85 253056 : sockets[i]->set_fast(this->fast);
86 96381 : }
87 :
88 : void
89 90105 : Task::set_debug(const bool debug)
90 : {
91 90105 : this->debug = debug;
92 90105 : }
93 :
94 : void
95 0 : Task::set_debug_hex(const bool debug_hex)
96 : {
97 0 : this->debug_hex = debug_hex;
98 0 : }
99 :
100 : void
101 90105 : Task::set_debug_limit(const uint32_t limit)
102 : {
103 90105 : this->debug_limit = (int32_t)limit;
104 90105 : }
105 :
106 : void
107 0 : Task::set_debug_precision(const uint8_t prec)
108 : {
109 0 : this->debug_precision = prec;
110 0 : }
111 :
112 : void
113 0 : Task::set_debug_frame_max(const uint32_t limit)
114 : {
115 0 : this->debug_frame_max = limit;
116 0 : }
117 :
118 : // trick to compile on the GNU compiler version 4 (where 'std::hexfloat' is unavailable)
119 : #if !defined(__clang__) && !defined(__llvm__) && defined(__GNUC__) && defined(__cplusplus) && __GNUC__ < 5
120 : namespace std
121 : {
122 : class Hexfloat
123 : {
124 : public:
125 : void message(std::ostream& os) const { os << " /!\\ 'std::hexfloat' is not supported by this compiler. /!\\ "; }
126 : };
127 : Hexfloat hexfloat;
128 : }
129 : std::ostream&
130 : operator<<(std::ostream& os, const std::Hexfloat& obj)
131 : {
132 : obj.message(os);
133 : return os;
134 : }
135 : #endif
136 :
137 : template<typename T>
138 : static inline void
139 401 : display_data(const T* data,
140 : const size_t fra_size,
141 : const size_t n_fra,
142 : const size_t n_fra_per_w,
143 : const size_t limit,
144 : const size_t max_frame,
145 : const uint8_t p,
146 : const uint8_t n_spaces,
147 : const bool hex)
148 : {
149 401 : constexpr bool is_float_type = std::is_same<float, T>::value || std::is_same<double, T>::value;
150 :
151 401 : std::ios::fmtflags f(std::cout.flags());
152 401 : if (hex)
153 : {
154 : if (is_float_type)
155 0 : std::cout << std::hexfloat << std::hex;
156 : else
157 0 : std::cout << std::hex;
158 : }
159 : else
160 401 : std::cout << std::fixed << std::setprecision(p) << std::dec;
161 :
162 400 : if (n_fra == 1 && max_frame != 0)
163 : {
164 5600 : for (size_t i = 0; i < limit; i++)
165 : {
166 5198 : if (hex)
167 0 : std::cout << (!is_float_type ? "0x" : "") << +data[i] << (i < limit - 1 ? ", " : "");
168 : else
169 5198 : std::cout << std::setw(p + 3) << +data[i] << (i < limit - 1 ? ", " : "");
170 : }
171 402 : std::cout << (limit < fra_size ? ", ..." : "");
172 401 : }
173 : else
174 : {
175 0 : std::string spaces = "#";
176 0 : for (uint8_t s = 0; s < n_spaces - 1; s++)
177 0 : spaces += " ";
178 :
179 0 : auto n_digits_dec = [](size_t f) -> size_t
180 : {
181 0 : size_t count = 0;
182 0 : while (f && ++count)
183 0 : f /= 10;
184 0 : return count;
185 : };
186 :
187 0 : const auto n_digits = n_digits_dec(max_frame);
188 0 : auto ftos = [&n_digits_dec, &n_digits](size_t f) -> std::string
189 : {
190 0 : auto n_zero = n_digits - n_digits_dec(f);
191 0 : std::string f_str = "";
192 0 : for (size_t z = 0; z < n_zero; z++)
193 0 : f_str += "0";
194 0 : f_str += std::to_string(f);
195 0 : return f_str;
196 0 : };
197 :
198 0 : const auto n_digits_w = n_digits_dec((max_frame / n_fra_per_w) == 0 ? 1 : (max_frame / n_fra_per_w));
199 0 : auto wtos = [&n_digits_dec, &n_digits_w](size_t w) -> std::string
200 : {
201 0 : auto n_zero = n_digits_w - n_digits_dec(w);
202 0 : std::string f_str = "";
203 0 : for (size_t z = 0; z < n_zero; z++)
204 0 : f_str += "0";
205 0 : f_str += std::to_string(w);
206 0 : return f_str;
207 0 : };
208 :
209 0 : for (size_t f = 0; f < max_frame; f++)
210 : {
211 0 : const auto w = f / n_fra_per_w;
212 0 : std::cout << (f >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f" << ftos(f + 1) << "_w"
213 0 : << wtos(w + 1) << rang::style::reset << "(";
214 :
215 0 : for (size_t i = 0; i < limit; i++)
216 : {
217 0 : if (hex)
218 0 : std::cout << (!is_float_type ? "0x" : "") << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
219 : else
220 0 : std::cout << std::setw(p + 3) << +data[f * fra_size + i] << (i < limit - 1 ? ", " : "");
221 : }
222 0 : std::cout << (limit < fra_size ? ", ..." : "") << ")" << (f < n_fra - 1 ? ", \n" : "");
223 : }
224 :
225 0 : if (max_frame < n_fra)
226 : {
227 0 : const auto w1 = max_frame / n_fra_per_w;
228 0 : const auto w2 = n_fra / n_fra_per_w;
229 0 : std::cout << (max_frame >= 1 ? spaces : "") << rang::style::bold << rang::fg::gray << "f"
230 0 : << std::setw(n_digits) << max_frame + 1 << "_w" << std::setw(n_digits_w) << w1 + 1 << "->"
231 0 : << "f" << std::setw(n_digits) << n_fra << "_w" << std::setw(n_digits_w) << w2 + 1 << ":"
232 0 : << rang::style::reset << "(...)";
233 : }
234 0 : }
235 :
236 401 : std::cout.flags(f);
237 401 : }
238 :
239 : void
240 3899804 : Task::_exec(const int frame_id, const bool managed_memory)
241 : {
242 3899804 : const auto n_frames = this->get_module().get_n_frames();
243 3869702 : const auto n_frames_per_wave = this->get_module().get_n_frames_per_wave();
244 3827502 : const auto n_waves = this->get_module().get_n_waves();
245 3785585 : const auto n_frames_per_wave_rest = this->get_module().get_n_frames_per_wave_rest();
246 :
247 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using the no
248 : // copy mode
249 3758478 : int* status = this->sockets.back()->get_dataptr<int>();
250 9545650 : for (size_t w = 0; w < n_waves; w++)
251 5996118 : status[w] = (int)status_t::UNKNOWN;
252 :
253 3549532 : if ((managed_memory == false && frame_id >= 0) || (frame_id == -1 && n_frames_per_wave == n_frames) ||
254 222018 : (frame_id == 0 && n_frames_per_wave == 1) || (frame_id == 0 && n_waves > 1) ||
255 0 : (frame_id == 0 && n_frames_per_wave_rest == 0))
256 : {
257 3329875 : const auto real_frame_id = frame_id == -1 ? 0 : frame_id;
258 3329875 : const size_t w = (real_frame_id % n_frames) / n_frames_per_wave;
259 3329875 : status[w] = this->codelet(*this->module, *this, real_frame_id);
260 3559030 : }
261 : else
262 : {
263 : // save the initial dataptr of the sockets
264 613243 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
265 387849 : sockets_dataptr_init[sid] = (int8_t*)this->sockets[sid]->_get_dataptr();
266 :
267 222135 : if (frame_id > 0 && managed_memory == true && n_frames_per_wave > 1)
268 : {
269 0 : const size_t w = (frame_id % n_frames) / n_frames_per_wave;
270 0 : const size_t w_pos = frame_id % n_frames_per_wave;
271 :
272 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
273 : {
274 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
275 0 : std::copy(
276 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 0) * sockets_databytes_per_frame[sid],
277 0 : sockets_dataptr_init[sid] + ((frame_id % n_frames) + 1) * sockets_databytes_per_frame[sid],
278 0 : sockets_data[sid].begin() + w_pos * sockets_databytes_per_frame[sid]);
279 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
280 : }
281 :
282 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
283 :
284 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
285 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
286 0 : std::copy(sockets_data[sid].begin() + (w_pos + 0) * sockets_databytes_per_frame[sid],
287 0 : sockets_data[sid].begin() + (w_pos + 1) * sockets_databytes_per_frame[sid],
288 0 : sockets_dataptr_init[sid] + (frame_id % n_frames) * sockets_databytes_per_frame[sid]);
289 0 : }
290 : else // if (frame_id <= 0 || n_frames_per_wave == 1)
291 : {
292 222135 : const size_t w_start = (frame_id < 0) ? 0 : frame_id % n_waves;
293 222135 : const size_t w_stop = (frame_id < 0) ? n_waves : w_start + 1;
294 :
295 222135 : size_t w = 0;
296 222135 : auto exec_status = status_t::SUCCESS;
297 2307098 : for (w = w_start; w < w_stop - 1 && exec_status != status_t::FAILURE_STOP; w++)
298 : {
299 5656644 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
300 3566658 : this->sockets[sid]->dataptr =
301 7131047 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
302 :
303 1659585 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
304 2084963 : exec_status = (status_t)status[w];
305 : }
306 :
307 222042 : if (exec_status != status_t::FAILURE_STOP)
308 : {
309 222289 : if (n_frames_per_wave_rest == 0)
310 : {
311 613397 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
312 391086 : this->sockets[sid]->dataptr =
313 781960 : (void*)(sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
314 :
315 212320 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
316 : }
317 : else
318 : {
319 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
320 : {
321 0 : if (sockets[sid]->get_type() == socket_t::SIN || sockets[sid]->get_type() == socket_t::SFWD)
322 0 : std::copy(sockets_dataptr_init[sid] +
323 0 : w * n_frames_per_wave * sockets_databytes_per_frame[sid],
324 0 : sockets_dataptr_init[sid] + n_frames * sockets_databytes_per_frame[sid],
325 0 : sockets_data[sid].begin());
326 0 : this->sockets[sid]->dataptr = (void*)sockets_data[sid].data();
327 : }
328 :
329 0 : status[w] = this->codelet(*this->module, *this, w * n_frames_per_wave);
330 :
331 0 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
332 0 : if (sockets[sid]->get_type() == socket_t::SOUT || sockets[sid]->get_type() == socket_t::SFWD)
333 0 : std::copy(
334 0 : sockets_data[sid].begin(),
335 0 : sockets_data[sid].begin() + n_frames_per_wave_rest * sockets_databytes_per_frame[sid],
336 0 : sockets_dataptr_init[sid] + w * n_frames_per_wave * sockets_databytes_per_frame[sid]);
337 : }
338 : }
339 : }
340 :
341 : // restore the initial dataptr of the sockets
342 612623 : for (size_t sid = 0; sid < this->sockets.size() - 1; sid++)
343 390417 : this->sockets[sid]->dataptr = (void*)sockets_dataptr_init[sid];
344 : }
345 3772431 : }
346 :
347 : const std::vector<int>&
348 3809249 : Task::exec(const int frame_id, const bool managed_memory)
349 : {
350 : #ifndef SPU_FAST
351 3809249 : if (this->is_fast() && !this->is_debug() && !this->is_stats())
352 : {
353 : #endif
354 3833438 : this->_exec(frame_id, managed_memory);
355 3730371 : this->n_calls++;
356 3730371 : return this->get_status();
357 : #ifndef SPU_FAST
358 : }
359 :
360 6513 : if (frame_id != -1 && (size_t)frame_id >= this->get_module().get_n_frames())
361 : {
362 0 : std::stringstream message;
363 0 : message << "'frame_id' has to be equal to '-1' or to be smaller than 'n_frames' ('frame_id' = " << frame_id
364 0 : << ", 'n_frames' = " << this->get_module().get_n_frames() << ").";
365 0 : throw tools::length_error(__FILE__, __LINE__, __func__, message.str());
366 0 : }
367 :
368 8781 : if (this->is_fast() || this->can_exec())
369 : {
370 8766 : size_t max_n_chars = 0;
371 8766 : if (this->is_debug())
372 : {
373 333 : auto n_fra = this->module->get_n_frames();
374 333 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
375 :
376 : std::string module_name =
377 333 : module->get_custom_name().empty() ? module->get_name() : module->get_custom_name();
378 :
379 333 : std::cout << "# ";
380 333 : std::cout << rang::style::bold << rang::fg::green << module_name << rang::style::reset
381 333 : << "::" << rang::style::bold << rang::fg::magenta << get_name() << rang::style::reset << "(";
382 736 : for (auto i = 0; i < (int)sockets.size() - 1; i++)
383 : {
384 403 : auto& s = *sockets[i];
385 403 : auto s_type = s.get_type();
386 403 : auto n_elmts = s.get_databytes() / (size_t)s.get_datatype_size();
387 403 : std::cout << rang::style::bold << rang::fg::blue << (s_type == socket_t::SIN ? "const " : "")
388 403 : << s.get_datatype_string() << rang::style::reset << " " << s.get_name() << "["
389 806 : << (n_fra > 1 ? std::to_string(n_fra) + "x" : "") << (n_elmts / n_fra) << "]"
390 403 : << (i < (int)sockets.size() - 2 ? ", " : "");
391 :
392 403 : max_n_chars = std::max(s.get_name().size(), max_n_chars);
393 : }
394 333 : std::cout << ")" << std::endl;
395 :
396 1069 : for (auto& s : sockets)
397 : {
398 736 : auto s_type = s->get_type();
399 736 : if (s_type == socket_t::SIN || s_type == socket_t::SFWD)
400 : {
401 217 : std::string spaces;
402 345 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
403 128 : spaces += " ";
404 :
405 217 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
406 217 : auto fra_size = n_elmts / n_fra;
407 217 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
408 217 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
409 217 : auto p = debug_precision;
410 217 : auto h = debug_hex;
411 217 : std::cout << "# {IN} " << s->get_name() << spaces << " = [";
412 217 : if (s->get_datatype() == typeid(int8_t))
413 0 : display_data(s->get_dataptr<const int8_t>(),
414 : fra_size,
415 : n_fra,
416 : n_fra_per_w,
417 : limit,
418 : max_frame,
419 : p,
420 0 : (uint8_t)max_n_chars + 12,
421 : h);
422 217 : else if (s->get_datatype() == typeid(uint8_t))
423 172 : display_data(s->get_dataptr<const uint8_t>(),
424 : fra_size,
425 : n_fra,
426 : n_fra_per_w,
427 : limit,
428 : max_frame,
429 : p,
430 172 : (uint8_t)max_n_chars + 12,
431 : h);
432 45 : else if (s->get_datatype() == typeid(int16_t))
433 0 : display_data(s->get_dataptr<const int16_t>(),
434 : fra_size,
435 : n_fra,
436 : n_fra_per_w,
437 : limit,
438 : max_frame,
439 : p,
440 0 : (uint8_t)max_n_chars + 12,
441 : h);
442 45 : else if (s->get_datatype() == typeid(uint16_t))
443 0 : display_data(s->get_dataptr<const uint16_t>(),
444 : fra_size,
445 : n_fra,
446 : n_fra_per_w,
447 : limit,
448 : max_frame,
449 : p,
450 0 : (uint8_t)max_n_chars + 12,
451 : h);
452 45 : else if (s->get_datatype() == typeid(int32_t))
453 0 : display_data(s->get_dataptr<const int32_t>(),
454 : fra_size,
455 : n_fra,
456 : n_fra_per_w,
457 : limit,
458 : max_frame,
459 : p,
460 0 : (uint8_t)max_n_chars + 12,
461 : h);
462 45 : else if (s->get_datatype() == typeid(uint32_t))
463 27 : display_data(s->get_dataptr<const uint32_t>(),
464 : fra_size,
465 : n_fra,
466 : n_fra_per_w,
467 : limit,
468 : max_frame,
469 : p,
470 27 : (uint8_t)max_n_chars + 12,
471 : h);
472 18 : else if (s->get_datatype() == typeid(int64_t))
473 0 : display_data(s->get_dataptr<const int64_t>(),
474 : fra_size,
475 : n_fra,
476 : n_fra_per_w,
477 : limit,
478 : max_frame,
479 : p,
480 0 : (uint8_t)max_n_chars + 12,
481 : h);
482 18 : else if (s->get_datatype() == typeid(uint64_t))
483 18 : display_data(s->get_dataptr<const uint64_t>(),
484 : fra_size,
485 : n_fra,
486 : n_fra_per_w,
487 : limit,
488 : max_frame,
489 : p,
490 18 : (uint8_t)max_n_chars + 12,
491 : h);
492 0 : else if (s->get_datatype() == typeid(float))
493 0 : display_data(s->get_dataptr<const float>(),
494 : fra_size,
495 : n_fra,
496 : n_fra_per_w,
497 : limit,
498 : max_frame,
499 : p,
500 0 : (uint8_t)max_n_chars + 12,
501 : h);
502 0 : else if (s->get_datatype() == typeid(double))
503 0 : display_data(s->get_dataptr<const double>(),
504 : fra_size,
505 : n_fra,
506 : n_fra_per_w,
507 : limit,
508 : max_frame,
509 : p,
510 0 : (uint8_t)max_n_chars + 12,
511 : h);
512 217 : std::cout << "]" << std::endl;
513 217 : }
514 : }
515 333 : }
516 :
517 8760 : if (this->is_stats())
518 : {
519 8048 : auto t_start = std::chrono::steady_clock::now();
520 8061 : this->_exec(frame_id, managed_memory);
521 7950 : auto duration = std::chrono::steady_clock::now() - t_start;
522 :
523 7961 : this->duration_total += duration;
524 7947 : if (n_calls)
525 : {
526 7490 : this->duration_min = std::min(this->duration_min, duration);
527 7493 : this->duration_max = std::max(this->duration_max, duration);
528 : }
529 : else
530 : {
531 457 : this->duration_min = duration;
532 457 : this->duration_max = duration;
533 : }
534 : }
535 : else
536 : {
537 712 : this->_exec(frame_id, managed_memory);
538 : }
539 8682 : this->n_calls++;
540 :
541 8682 : if (this->is_debug())
542 : {
543 330 : auto n_fra = this->module->get_n_frames();
544 330 : auto n_fra_per_w = this->module->get_n_frames_per_wave();
545 1058 : for (auto& s : sockets)
546 : {
547 728 : auto s_type = s->get_type();
548 728 : if ((s_type == socket_t::SOUT) && s->get_name() != "status")
549 : {
550 184 : std::string spaces;
551 202 : for (size_t ss = 0; ss < max_n_chars - s->get_name().size(); ss++)
552 18 : spaces += " ";
553 :
554 184 : auto n_elmts = s->get_databytes() / (size_t)s->get_datatype_size();
555 184 : auto fra_size = n_elmts / n_fra;
556 184 : auto limit = debug_limit != -1 ? std::min(fra_size, (size_t)debug_limit) : fra_size;
557 184 : auto max_frame = debug_frame_max != -1 ? std::min(n_fra, (size_t)debug_frame_max) : n_fra;
558 184 : auto p = debug_precision;
559 184 : auto h = debug_hex;
560 184 : std::cout << "# {OUT} " << s->get_name() << spaces << " = [";
561 184 : if (s->get_datatype() == typeid(int8_t))
562 0 : display_data(s->get_dataptr<const int8_t>(),
563 : fra_size,
564 : n_fra,
565 : n_fra_per_w,
566 : limit,
567 : max_frame,
568 : p,
569 0 : (uint8_t)max_n_chars + 12,
570 : h);
571 184 : else if (s->get_datatype() == typeid(uint8_t))
572 147 : display_data(s->get_dataptr<const uint8_t>(),
573 : fra_size,
574 : n_fra,
575 : n_fra_per_w,
576 : limit,
577 : max_frame,
578 : p,
579 148 : (uint8_t)max_n_chars + 12,
580 : h);
581 36 : else if (s->get_datatype() == typeid(int16_t))
582 0 : display_data(s->get_dataptr<const int16_t>(),
583 : fra_size,
584 : n_fra,
585 : n_fra_per_w,
586 : limit,
587 : max_frame,
588 : p,
589 0 : (uint8_t)max_n_chars + 12,
590 : h);
591 36 : else if (s->get_datatype() == typeid(uint16_t))
592 0 : display_data(s->get_dataptr<const uint16_t>(),
593 : fra_size,
594 : n_fra,
595 : n_fra_per_w,
596 : limit,
597 : max_frame,
598 : p,
599 0 : (uint8_t)max_n_chars + 12,
600 : h);
601 36 : else if (s->get_datatype() == typeid(int32_t))
602 0 : display_data(s->get_dataptr<const int32_t>(),
603 : fra_size,
604 : n_fra,
605 : n_fra_per_w,
606 : limit,
607 : max_frame,
608 : p,
609 0 : (uint8_t)max_n_chars + 12,
610 : h);
611 36 : else if (s->get_datatype() == typeid(uint32_t))
612 18 : display_data(s->get_dataptr<const uint32_t>(),
613 : fra_size,
614 : n_fra,
615 : n_fra_per_w,
616 : limit,
617 : max_frame,
618 : p,
619 18 : (uint8_t)max_n_chars + 12,
620 : h);
621 18 : else if (s->get_datatype() == typeid(int64_t))
622 0 : display_data(s->get_dataptr<const int64_t>(),
623 : fra_size,
624 : n_fra,
625 : n_fra_per_w,
626 : limit,
627 : max_frame,
628 : p,
629 0 : (uint8_t)max_n_chars + 12,
630 : h);
631 18 : else if (s->get_datatype() == typeid(uint64_t))
632 18 : display_data(s->get_dataptr<const uint64_t>(),
633 : fra_size,
634 : n_fra,
635 : n_fra_per_w,
636 : limit,
637 : max_frame,
638 : p,
639 18 : (uint8_t)max_n_chars + 12,
640 : h);
641 0 : else if (s->get_datatype() == typeid(float))
642 0 : display_data(s->get_dataptr<const float>(),
643 : fra_size,
644 : n_fra,
645 : n_fra_per_w,
646 : limit,
647 : max_frame,
648 : p,
649 0 : (uint8_t)max_n_chars + 12,
650 : h);
651 0 : else if (s->get_datatype() == typeid(double))
652 0 : display_data(s->get_dataptr<const double>(),
653 : fra_size,
654 : n_fra,
655 : n_fra_per_w,
656 : limit,
657 : max_frame,
658 : p,
659 0 : (uint8_t)max_n_chars + 12,
660 : h);
661 184 : std::cout << "]" << std::endl;
662 184 : }
663 : }
664 330 : std::cout << "# Returned status: [";
665 : // do not use 'this->status' because the dataptr can have been changed by the 'tools::Sequence' when using
666 : // the no copy mode
667 330 : int* status = this->sockets.back()->get_dataptr<int>();
668 660 : for (size_t w = 0; w < this->get_module().get_n_waves(); w++)
669 : {
670 330 : if (status_t_to_string.count(status[w]))
671 330 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w] << " '"
672 330 : << status_t_to_string[status[w]] << "'";
673 : else
674 0 : std::cout << ((w != 0) ? ", " : "") << std::dec << status[w];
675 : }
676 330 : std::cout << "]" << std::endl;
677 330 : std::cout << "#" << std::noshowbase << std::endl;
678 : }
679 :
680 : // if (exec_status < 0)
681 : // {
682 : // std::stringstream message;
683 : // message << "'exec_status' can't be negative ('exec_status' = " << exec_status << ").";
684 : // throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
685 : // }
686 :
687 8674 : return this->get_status();
688 : }
689 : else
690 : {
691 0 : std::stringstream input_socs;
692 0 : std::stringstream output_socs;
693 0 : input_socs << "'socket(s).name' = [";
694 0 : output_socs << "'socket(s).name' = [";
695 0 : auto s = 0;
696 0 : for (size_t i = 0; i < sockets.size(); i++)
697 : {
698 0 : if (sockets[i]->dataptr == nullptr)
699 : {
700 0 : if (sockets[i]->get_type() == runtime::socket_t::SOUT)
701 : {
702 0 : output_socs << (s != 0 ? ", " : "") << sockets[i]->name;
703 : }
704 : else
705 : {
706 0 : input_socs << (s != 0 ? ", " : "") << sockets[i]->name;
707 : }
708 : }
709 : }
710 0 : std::stringstream message;
711 0 : message << "The task ('task.name' = " << this->get_name() << ", 'module.name' = " << module->get_name()
712 0 : << ", cannot be executed because : " << std::endl
713 0 : << "The inputs/forward sockets : " << input_socs.str() << "] are not fed" << std::endl
714 0 : << "The output sockets : " << output_socs.str() << "] are not allocated";
715 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
716 0 : }
717 : #endif /* !SPU_FAST */
718 : }
719 :
720 : template<typename T>
721 : Socket&
722 9462 : Task::create_2d_socket(const std::string& name,
723 : const size_t n_rows,
724 : const size_t n_cols,
725 : const socket_t type,
726 : const bool hack_status)
727 : {
728 9462 : if (name.empty())
729 : {
730 0 : std::stringstream message;
731 : message << "Impossible to create this socket because the name is empty ('task.name' = " << this->get_name()
732 0 : << ", 'module.name' = " << module->get_name() << ").";
733 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
734 0 : }
735 :
736 9462 : if (name == "status" && !hack_status)
737 : {
738 0 : std::stringstream message;
739 0 : message << "A socket can't be named 'status'.";
740 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
741 0 : }
742 :
743 17683 : for (auto& s : sockets)
744 8221 : if (s->get_name() == name)
745 : {
746 0 : std::stringstream message;
747 : message << "Impossible to create this socket because an other socket has the same name ('socket.name' = "
748 0 : << name << ", 'task.name' = " << this->get_name() << ", 'module.name' = " << module->get_name()
749 0 : << ").";
750 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
751 0 : }
752 :
753 17683 : for (auto s : this->sockets)
754 8221 : if (s->get_name() == "status")
755 : {
756 0 : std::stringstream message;
757 0 : message << "Creating new sockets after the 'status' socket is forbidden.";
758 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
759 0 : }
760 :
761 9462 : std::pair<size_t, size_t> databytes_per_dim = { n_rows, n_cols * sizeof(T) };
762 9462 : auto s = std::make_shared<Socket>(*this, name, typeid(T), databytes_per_dim, type, this->is_fast());
763 :
764 9462 : sockets.push_back(std::move(s));
765 :
766 9462 : this->sockets_dataptr_init.push_back(nullptr);
767 9462 : this->sockets_databytes_per_frame.push_back(sockets.back()->get_databytes() / this->get_module().get_n_frames());
768 9462 : this->sockets_data.push_back(
769 18924 : std::vector<int8_t>((this->get_module().get_n_frames_per_wave() > 1)
770 0 : ? this->sockets_databytes_per_frame.back() * this->get_module().get_n_frames_per_wave()
771 : : 0));
772 :
773 18924 : return *sockets.back();
774 9462 : }
775 :
776 : template<typename T>
777 : size_t
778 2306 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols)
779 : {
780 2306 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SIN);
781 2306 : last_input_socket = &s;
782 :
783 2306 : this->n_input_sockets++;
784 :
785 2306 : return sockets.size() - 1;
786 : }
787 :
788 : size_t
789 747 : Task::create_2d_socket_in(const std::string& name,
790 : const size_t n_rows,
791 : const size_t n_cols,
792 : const std::type_index& datatype)
793 : {
794 747 : if (datatype == typeid(int8_t))
795 74 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
796 673 : else if (datatype == typeid(uint8_t))
797 507 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
798 166 : else if (datatype == typeid(int16_t))
799 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
800 166 : else if (datatype == typeid(uint16_t))
801 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
802 166 : else if (datatype == typeid(int32_t))
803 33 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
804 133 : else if (datatype == typeid(uint32_t))
805 115 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
806 18 : else if (datatype == typeid(int64_t))
807 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
808 18 : else if (datatype == typeid(uint64_t))
809 18 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
810 0 : else if (datatype == typeid(float))
811 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
812 0 : else if (datatype == typeid(double))
813 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
814 : else
815 : {
816 0 : std::stringstream message;
817 0 : message << "This should never happen.";
818 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
819 0 : }
820 : }
821 :
822 : size_t
823 0 : Task::create_2d_socket_in(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
824 : {
825 0 : switch (datatype)
826 : {
827 0 : case datatype_t::F64:
828 0 : return this->template create_2d_socket_in<double>(name, n_rows, n_cols);
829 : break;
830 0 : case datatype_t::F32:
831 0 : return this->template create_2d_socket_in<float>(name, n_rows, n_cols);
832 : break;
833 0 : case datatype_t::S64:
834 0 : return this->template create_2d_socket_in<int64_t>(name, n_rows, n_cols);
835 : break;
836 0 : case datatype_t::S32:
837 0 : return this->template create_2d_socket_in<int32_t>(name, n_rows, n_cols);
838 : break;
839 0 : case datatype_t::S16:
840 0 : return this->template create_2d_socket_in<int16_t>(name, n_rows, n_cols);
841 : break;
842 0 : case datatype_t::S8:
843 0 : return this->template create_2d_socket_in<int8_t>(name, n_rows, n_cols);
844 : break;
845 0 : case datatype_t::U64:
846 0 : return this->template create_2d_socket_in<uint64_t>(name, n_rows, n_cols);
847 : break;
848 0 : case datatype_t::U32:
849 0 : return this->template create_2d_socket_in<uint32_t>(name, n_rows, n_cols);
850 : break;
851 0 : case datatype_t::U16:
852 0 : return this->template create_2d_socket_in<uint16_t>(name, n_rows, n_cols);
853 : break;
854 0 : case datatype_t::U8:
855 0 : return this->template create_2d_socket_in<uint8_t>(name, n_rows, n_cols);
856 : break;
857 0 : default:
858 : {
859 0 : std::stringstream message;
860 0 : message << "This should never happen.";
861 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
862 : break;
863 0 : }
864 : }
865 : }
866 :
867 : template<typename T>
868 : size_t
869 5973 : Task::create_2d_socket_out(const std::string& name, const size_t n_rows, const size_t n_cols, const bool hack_status)
870 : {
871 5973 : create_2d_socket<T>(name, n_rows, n_cols, socket_t::SOUT, hack_status);
872 5973 : this->n_output_sockets++;
873 :
874 5973 : return sockets.size() - 1;
875 : }
876 :
877 : size_t
878 740 : Task::create_2d_socket_out(const std::string& name,
879 : const size_t n_rows,
880 : const size_t n_cols,
881 : const std::type_index& datatype,
882 : const bool hack_status)
883 : {
884 740 : if (datatype == typeid(int8_t))
885 67 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
886 673 : else if (datatype == typeid(uint8_t))
887 507 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
888 166 : else if (datatype == typeid(int16_t))
889 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
890 166 : else if (datatype == typeid(uint16_t))
891 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
892 166 : else if (datatype == typeid(int32_t))
893 33 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
894 133 : else if (datatype == typeid(uint32_t))
895 115 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
896 18 : else if (datatype == typeid(int64_t))
897 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
898 18 : else if (datatype == typeid(uint64_t))
899 18 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
900 0 : else if (datatype == typeid(float))
901 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
902 0 : else if (datatype == typeid(double))
903 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
904 : else
905 : {
906 0 : std::stringstream message;
907 0 : message << "This should never happen.";
908 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
909 0 : }
910 : }
911 :
912 : size_t
913 0 : Task::create_2d_socket_out(const std::string& name,
914 : const size_t n_rows,
915 : const size_t n_cols,
916 : const datatype_t datatype,
917 : const bool hack_status)
918 : {
919 0 : switch (datatype)
920 : {
921 0 : case datatype_t::F64:
922 0 : return this->template create_2d_socket_out<double>(name, n_rows, n_cols, hack_status);
923 : break;
924 0 : case datatype_t::F32:
925 0 : return this->template create_2d_socket_out<float>(name, n_rows, n_cols, hack_status);
926 : break;
927 0 : case datatype_t::S64:
928 0 : return this->template create_2d_socket_out<int64_t>(name, n_rows, n_cols, hack_status);
929 : break;
930 0 : case datatype_t::S32:
931 0 : return this->template create_2d_socket_out<int32_t>(name, n_rows, n_cols, hack_status);
932 : break;
933 0 : case datatype_t::S16:
934 0 : return this->template create_2d_socket_out<int16_t>(name, n_rows, n_cols, hack_status);
935 : break;
936 0 : case datatype_t::S8:
937 0 : return this->template create_2d_socket_out<int8_t>(name, n_rows, n_cols, hack_status);
938 : break;
939 0 : case datatype_t::U64:
940 0 : return this->template create_2d_socket_out<uint64_t>(name, n_rows, n_cols, hack_status);
941 : break;
942 0 : case datatype_t::U32:
943 0 : return this->template create_2d_socket_out<uint32_t>(name, n_rows, n_cols, hack_status);
944 : break;
945 0 : case datatype_t::U16:
946 0 : return this->template create_2d_socket_out<uint16_t>(name, n_rows, n_cols, hack_status);
947 : break;
948 0 : case datatype_t::U8:
949 0 : return this->template create_2d_socket_out<uint8_t>(name, n_rows, n_cols, hack_status);
950 : break;
951 0 : default:
952 : {
953 0 : std::stringstream message;
954 0 : message << "This should never happen.";
955 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
956 : break;
957 0 : }
958 : }
959 : }
960 :
961 : template<typename T>
962 : size_t
963 1183 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols)
964 : {
965 1183 : auto& s = create_2d_socket<T>(name, n_rows, n_cols, socket_t::SFWD);
966 1183 : last_input_socket = &s;
967 :
968 1183 : this->n_fwd_sockets++;
969 :
970 1183 : return sockets.size() - 1;
971 : }
972 :
973 : size_t
974 0 : Task::create_2d_socket_fwd(const std::string& name,
975 : const size_t n_rows,
976 : const size_t n_cols,
977 : const std::type_index& datatype)
978 : {
979 0 : if (datatype == typeid(int8_t))
980 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
981 0 : else if (datatype == typeid(uint8_t))
982 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
983 0 : else if (datatype == typeid(int16_t))
984 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
985 0 : else if (datatype == typeid(uint16_t))
986 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
987 0 : else if (datatype == typeid(int32_t))
988 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
989 0 : else if (datatype == typeid(uint32_t))
990 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
991 0 : else if (datatype == typeid(int64_t))
992 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
993 0 : else if (datatype == typeid(uint64_t))
994 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
995 0 : else if (datatype == typeid(float))
996 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
997 0 : else if (datatype == typeid(double))
998 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
999 : else
1000 : {
1001 0 : std::stringstream message;
1002 0 : message << "This should never happen.";
1003 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1004 0 : }
1005 : }
1006 :
1007 : size_t
1008 0 : Task::create_2d_socket_fwd(const std::string& name, const size_t n_rows, const size_t n_cols, const datatype_t datatype)
1009 : {
1010 0 : switch (datatype)
1011 : {
1012 0 : case datatype_t::F64:
1013 0 : return this->template create_2d_socket_fwd<double>(name, n_rows, n_cols);
1014 : break;
1015 0 : case datatype_t::F32:
1016 0 : return this->template create_2d_socket_fwd<float>(name, n_rows, n_cols);
1017 : break;
1018 0 : case datatype_t::S64:
1019 0 : return this->template create_2d_socket_fwd<int64_t>(name, n_rows, n_cols);
1020 : break;
1021 0 : case datatype_t::S32:
1022 0 : return this->template create_2d_socket_fwd<int32_t>(name, n_rows, n_cols);
1023 : break;
1024 0 : case datatype_t::S16:
1025 0 : return this->template create_2d_socket_fwd<int16_t>(name, n_rows, n_cols);
1026 : break;
1027 0 : case datatype_t::S8:
1028 0 : return this->template create_2d_socket_fwd<int8_t>(name, n_rows, n_cols);
1029 : break;
1030 0 : case datatype_t::U64:
1031 0 : return this->template create_2d_socket_fwd<uint64_t>(name, n_rows, n_cols);
1032 : break;
1033 0 : case datatype_t::U32:
1034 0 : return this->template create_2d_socket_fwd<uint32_t>(name, n_rows, n_cols);
1035 : break;
1036 0 : case datatype_t::U16:
1037 0 : return this->template create_2d_socket_fwd<uint16_t>(name, n_rows, n_cols);
1038 : break;
1039 0 : case datatype_t::U8:
1040 0 : return this->template create_2d_socket_fwd<uint8_t>(name, n_rows, n_cols);
1041 : break;
1042 0 : default:
1043 : {
1044 0 : std::stringstream message;
1045 0 : message << "This should never happen.";
1046 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1047 : break;
1048 0 : }
1049 : }
1050 : }
1051 :
1052 : void
1053 3805 : Task::create_codelet(std::function<int(module::Module& m, Task& t, const size_t frame_id)>& codelet)
1054 : {
1055 3805 : this->codelet = codelet;
1056 :
1057 : // create automatically a socket that contains the status of the task
1058 3805 : const bool hack_status = true;
1059 3805 : auto s = this->template create_2d_socket_out<int>("status", 1, this->get_module().get_n_waves(), hack_status);
1060 3805 : this->sockets[s]->dataptr = (void*)this->status.data();
1061 :
1062 3805 : if (tools::Buffer_allocator::get_task_autoalloc()) this->allocate_outbuffers();
1063 3805 : }
1064 :
1065 : void
1066 98510 : Task::update_n_frames(const size_t old_n_frames, const size_t new_n_frames)
1067 : {
1068 353544 : for (auto& s : this->sockets)
1069 : {
1070 255034 : if (s->get_name() == "status")
1071 : {
1072 98510 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1073 : {
1074 86520 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1075 86520 : this->status.resize(this->get_module().get_n_waves());
1076 86520 : s->set_dataptr((void*)this->status.data());
1077 : }
1078 : }
1079 : else
1080 : {
1081 156524 : const auto old_databytes = s->get_databytes();
1082 156524 : const auto new_databytes = (old_databytes / old_n_frames) * new_n_frames;
1083 156524 : s->set_databytes(new_databytes);
1084 :
1085 156524 : const size_t prev_n_rows_wo_nfra = s->get_n_rows() / old_n_frames;
1086 156524 : s->set_n_rows(prev_n_rows_wo_nfra * new_n_frames);
1087 :
1088 156524 : if (s->get_type() == socket_t::SOUT)
1089 : {
1090 62014 : s->resize_out_buffer(new_databytes);
1091 : }
1092 : }
1093 : }
1094 98510 : }
1095 :
1096 : void
1097 23980 : Task::update_n_frames_per_wave(const size_t /*old_n_frames_per_wave*/, const size_t new_n_frames_per_wave)
1098 : {
1099 23980 : size_t s_id = 0;
1100 96944 : for (auto& s : this->sockets)
1101 : {
1102 72964 : if (s->get_name() == "status")
1103 : {
1104 23980 : if (this->get_module().get_n_waves() * sizeof(int) != s->get_databytes())
1105 : {
1106 11990 : s->set_databytes(this->get_module().get_n_waves() * sizeof(int));
1107 11990 : this->status.resize(this->get_module().get_n_waves());
1108 11990 : s->set_dataptr((void*)this->status.data());
1109 : }
1110 : }
1111 : else
1112 : {
1113 73476 : this->sockets_data[s_id].resize(
1114 24492 : (new_n_frames_per_wave > 1) ? this->sockets_databytes_per_frame[s_id] * new_n_frames_per_wave : 0);
1115 : }
1116 72964 : s_id++;
1117 : }
1118 23980 : }
1119 :
1120 : void
1121 59759 : Task::allocate_outbuffers()
1122 : {
1123 59759 : if (!this->is_outbuffers_allocated())
1124 : {
1125 : std::function<void(Socket * socket, void* data_ptr)> spread_dataptr =
1126 55686 : [&spread_dataptr](Socket* socket, void* data_ptr)
1127 : {
1128 94074 : for (auto bound_socket : socket->get_bound_sockets())
1129 : {
1130 44478 : if (bound_socket->get_type() == socket_t::SIN)
1131 : {
1132 38388 : bound_socket->set_dataptr(data_ptr);
1133 : }
1134 6090 : else if (bound_socket->get_type() == socket_t::SFWD)
1135 : {
1136 6090 : bound_socket->set_dataptr(data_ptr);
1137 6090 : spread_dataptr(bound_socket, data_ptr);
1138 : }
1139 : else
1140 : {
1141 0 : std::stringstream message;
1142 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1143 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1144 0 : }
1145 : }
1146 103032 : };
1147 203502 : for (auto s : this->sockets)
1148 : {
1149 150066 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1150 : {
1151 43555 : if (s->get_dataptr() == nullptr)
1152 : {
1153 43506 : s->allocate_buffer();
1154 43506 : spread_dataptr(s.get(), s->get_dataptr());
1155 : }
1156 : }
1157 150066 : }
1158 53436 : this->set_outbuffers_allocated(true);
1159 53436 : }
1160 59759 : }
1161 : void
1162 14777 : Task::deallocate_outbuffers()
1163 : {
1164 14777 : if (!this->is_outbuffers_allocated())
1165 : {
1166 0 : std::stringstream message;
1167 : message << "Task out sockets buffers are not allocated"
1168 0 : << ", task name : " << this->get_name();
1169 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1170 0 : }
1171 13815 : std::function<void(Socket * socket)> spread_nullptr = [&spread_nullptr](Socket* socket)
1172 : {
1173 22360 : for (auto bound_socket : socket->get_bound_sockets())
1174 : {
1175 9439 : if (bound_socket->get_type() == socket_t::SIN)
1176 : {
1177 8545 : bound_socket->set_dataptr(nullptr);
1178 : }
1179 894 : else if (bound_socket->get_type() == socket_t::SFWD)
1180 : {
1181 894 : bound_socket->set_dataptr(nullptr);
1182 894 : spread_nullptr(bound_socket);
1183 : }
1184 0 : else if (dynamic_cast<const module::Set*>(&bound_socket->get_task().get_module()))
1185 : {
1186 : // hack: for set that bind SOUT to SOUT for perf
1187 0 : bound_socket->set_dataptr(nullptr);
1188 : }
1189 : else
1190 : {
1191 0 : std::stringstream message;
1192 0 : message << "bound socket is of type SOUT, but should be SIN or SFWD";
1193 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1194 0 : }
1195 : }
1196 27698 : };
1197 55616 : for (auto s : this->sockets)
1198 : {
1199 40839 : if (s->get_type() == socket_t::SOUT && s->get_name() != "status")
1200 : {
1201 12027 : if (s->get_dataptr() != nullptr)
1202 : {
1203 12027 : s->deallocate_buffer();
1204 12027 : spread_nullptr(s.get());
1205 : }
1206 : }
1207 40839 : }
1208 14777 : this->set_outbuffers_allocated(false);
1209 14777 : }
1210 :
1211 : bool
1212 726 : Task::can_exec() const
1213 : {
1214 2822 : for (size_t i = 0; i < sockets.size(); i++)
1215 2039 : if (sockets[i]->dataptr == nullptr) return false;
1216 710 : return true;
1217 : }
1218 :
1219 : std::chrono::nanoseconds
1220 11036 : Task::get_duration_total() const
1221 : {
1222 11036 : return this->duration_total;
1223 : }
1224 :
1225 : std::chrono::nanoseconds
1226 1381 : Task::get_duration_avg() const
1227 : {
1228 1381 : return this->duration_total / this->n_calls;
1229 : }
1230 :
1231 : std::chrono::nanoseconds
1232 1532 : Task::get_duration_min() const
1233 : {
1234 1532 : return this->duration_min;
1235 : }
1236 :
1237 : std::chrono::nanoseconds
1238 1532 : Task::get_duration_max() const
1239 : {
1240 1532 : return this->duration_max;
1241 : }
1242 :
1243 : const std::vector<std::string>&
1244 385 : Task::get_timers_name() const
1245 : {
1246 385 : return this->timers_name;
1247 : }
1248 :
1249 : const std::vector<uint32_t>&
1250 59 : Task::get_timers_n_calls() const
1251 : {
1252 59 : return this->timers_n_calls;
1253 : }
1254 :
1255 : const std::vector<std::chrono::nanoseconds>&
1256 59 : Task::get_timers_total() const
1257 : {
1258 59 : return this->timers_total;
1259 : }
1260 :
1261 : const std::vector<std::chrono::nanoseconds>&
1262 59 : Task::get_timers_min() const
1263 : {
1264 59 : return this->timers_min;
1265 : }
1266 :
1267 : const std::vector<std::chrono::nanoseconds>&
1268 59 : Task::get_timers_max() const
1269 : {
1270 59 : return this->timers_max;
1271 : }
1272 :
1273 : size_t
1274 3189 : Task::get_n_input_sockets() const
1275 : {
1276 3189 : return this->n_input_sockets;
1277 : }
1278 :
1279 : size_t
1280 0 : Task::get_n_output_sockets() const
1281 : {
1282 0 : return this->n_output_sockets;
1283 : }
1284 :
1285 : size_t
1286 4330 : Task::get_n_fwd_sockets() const
1287 : {
1288 4330 : return this->n_fwd_sockets;
1289 : }
1290 :
1291 : void
1292 0 : Task::register_timer(const std::string& name)
1293 : {
1294 0 : this->timers_name.push_back(name);
1295 0 : this->timers_n_calls.push_back(0);
1296 0 : this->timers_total.push_back(std::chrono::nanoseconds(0));
1297 0 : this->timers_max.push_back(std::chrono::nanoseconds(0));
1298 0 : this->timers_min.push_back(std::chrono::nanoseconds(0));
1299 0 : }
1300 :
1301 : void
1302 90197 : Task::reset()
1303 : {
1304 90197 : this->n_calls = 0;
1305 90197 : this->duration_total = std::chrono::nanoseconds(0);
1306 90197 : this->duration_min = std::chrono::nanoseconds(0);
1307 90197 : this->duration_max = std::chrono::nanoseconds(0);
1308 :
1309 90197 : for (auto& x : this->timers_n_calls)
1310 0 : x = 0;
1311 90197 : for (auto& x : this->timers_total)
1312 0 : x = std::chrono::nanoseconds(0);
1313 90197 : for (auto& x : this->timers_min)
1314 0 : x = std::chrono::nanoseconds(0);
1315 90197 : for (auto& x : this->timers_max)
1316 0 : x = std::chrono::nanoseconds(0);
1317 90197 : }
1318 :
1319 : Task*
1320 83178 : Task::clone() const
1321 : {
1322 83178 : Task* t = new Task(*this);
1323 83178 : t->sockets.clear();
1324 83178 : t->last_input_socket = nullptr;
1325 83178 : t->fake_input_sockets.clear();
1326 83178 : t->set_outbuffers_allocated(false);
1327 :
1328 300396 : for (auto s : this->sockets)
1329 : {
1330 217218 : void* dataptr = nullptr;
1331 217218 : if (s->get_type() == socket_t::SOUT)
1332 : {
1333 136399 : if (s->get_name() == "status")
1334 : {
1335 83178 : dataptr = (void*)t->status.data();
1336 : }
1337 : }
1338 80819 : else if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1339 80819 : dataptr = s->_get_dataptr();
1340 :
1341 : // No need to allocate memory when cloning
1342 217218 : const std::pair<size_t, size_t> databytes_per_dim = { s->get_n_rows(), s->get_databytes() / s->get_n_rows() };
1343 : auto s_new = std::shared_ptr<Socket>(
1344 217218 : new Socket(*t, s->get_name(), s->get_datatype(), databytes_per_dim, s->get_type(), s->is_fast(), dataptr));
1345 217218 : t->sockets.push_back(s_new);
1346 :
1347 217218 : if (s_new->get_type() == socket_t::SIN || s_new->get_type() == socket_t::SFWD)
1348 80819 : t->last_input_socket = s_new.get();
1349 217218 : }
1350 :
1351 83178 : if (tools::Buffer_allocator::get_task_autoalloc()) t->allocate_outbuffers();
1352 :
1353 83178 : return t;
1354 : }
1355 :
1356 : void
1357 10438 : Task::_bind(Socket& s_out, const int priority)
1358 : {
1359 : // check if the 's_out' socket is already used for an other fake input socket
1360 10438 : bool already_bound = false;
1361 10438 : for (auto& fsi : this->fake_input_sockets)
1362 0 : if (&fsi->get_bound_socket() == &s_out)
1363 : {
1364 0 : already_bound = true;
1365 0 : break;
1366 : }
1367 :
1368 : // check if the 's_out' socket is already used for an other read input/fwd socket
1369 10438 : if (!already_bound)
1370 31831 : for (auto& s : this->sockets)
1371 21393 : if (s->get_type() == socket_t::SIN || s->get_type() == socket_t::SFWD)
1372 : {
1373 : try // because 's->get_bound_socket()' can throw if s->bound_socket == 'nullptr'
1374 : {
1375 1891 : if (&s->get_bound_socket() == &s_out)
1376 : {
1377 0 : already_bound = true;
1378 0 : break;
1379 : }
1380 : }
1381 864 : catch (...)
1382 : {
1383 864 : }
1384 : }
1385 :
1386 : // if the s_out socket is not already bound, then create a new fake input socket
1387 10438 : if (!already_bound)
1388 : {
1389 10438 : this->fake_input_sockets.push_back(
1390 20876 : std::shared_ptr<Socket>(new Socket(*this,
1391 20876 : "fake" + std::to_string(this->fake_input_sockets.size()),
1392 10438 : s_out.get_datatype(),
1393 10438 : s_out.get_databytes(),
1394 : socket_t::SIN,
1395 10438 : this->is_fast())));
1396 10438 : this->fake_input_sockets.back()->_bind(s_out, priority);
1397 10438 : this->last_input_socket = this->fake_input_sockets.back().get();
1398 10438 : this->n_input_sockets++;
1399 : }
1400 10438 : }
1401 :
1402 : void
1403 0 : Task::bind(Socket& s_out, const int priority)
1404 : {
1405 : #ifdef SPU_SHOW_DEPRECATED
1406 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1407 : << std::endl;
1408 : #ifdef SPU_STACKTRACE
1409 : #ifdef SPU_COLORS
1410 : bool enable_color = true;
1411 : #else
1412 : bool enable_color = false;
1413 : #endif
1414 : cpptrace::generate_trace().print(std::clog, enable_color);
1415 : #endif
1416 : #endif
1417 0 : this->_bind(s_out, priority);
1418 0 : }
1419 :
1420 : void
1421 0 : Task::_bind(Task& t_out, const int priority)
1422 : {
1423 0 : this->_bind(*t_out.sockets.back(), priority);
1424 0 : }
1425 :
1426 : void
1427 0 : Task::bind(Task& t_out, const int priority)
1428 : {
1429 : #ifdef SPU_SHOW_DEPRECATED
1430 : std::clog << rang::tag::warning << "Deprecated: 'Task::bind()' should be replaced by 'Task::operator='."
1431 : << std::endl;
1432 : #ifdef SPU_STACKTRACE
1433 : #ifdef SPU_COLORS
1434 : bool enable_color = true;
1435 : #else
1436 : bool enable_color = false;
1437 : #endif
1438 : cpptrace::generate_trace().print(std::clog, enable_color);
1439 : #endif
1440 : #endif
1441 0 : this->_bind(t_out, priority);
1442 0 : }
1443 :
1444 : void
1445 8284 : Task::operator=(Socket& s_out)
1446 : {
1447 : #ifndef SPU_FAST
1448 8284 : if (s_out.get_type() == socket_t::SOUT || s_out.get_type() == socket_t::SFWD)
1449 : #endif
1450 8284 : this->_bind(s_out);
1451 : #ifndef SPU_FAST
1452 : else
1453 : {
1454 0 : std::stringstream message;
1455 : message << "'s_out' should be and output socket ("
1456 0 : << "'s_out.datatype' = " << type_to_string[s_out.get_datatype()] << ", "
1457 0 : << "'s_out.name' = " << s_out.get_name() << ", "
1458 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1459 0 : << "'s_out.type' = " << (s_out.get_type() == socket_t::SIN ? "SIN" : "SOUT") << ", "
1460 0 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ").";
1461 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1462 0 : }
1463 : #endif
1464 8284 : }
1465 :
1466 : void
1467 342 : Task::operator=(Task& t_out)
1468 : {
1469 342 : (*this) = *t_out.sockets.back();
1470 342 : }
1471 :
1472 : size_t
1473 58077 : Task::unbind(Socket& s_out)
1474 : {
1475 58077 : if (this->fake_input_sockets.size())
1476 : {
1477 7239 : size_t i = 0;
1478 7305 : for (auto& fsi : this->fake_input_sockets)
1479 : {
1480 7239 : if (&fsi->get_bound_socket() == &s_out)
1481 : {
1482 7173 : const auto pos = fsi->unbind(s_out);
1483 7173 : if (this->last_input_socket == fsi.get()) this->last_input_socket = nullptr;
1484 7173 : this->fake_input_sockets.erase(this->fake_input_sockets.begin() + i);
1485 7173 : this->n_input_sockets--;
1486 7173 : if (this->fake_input_sockets.size() && this->last_input_socket == nullptr)
1487 0 : this->last_input_socket = this->fake_input_sockets.back().get();
1488 7173 : return pos;
1489 : }
1490 66 : i++;
1491 : }
1492 :
1493 66 : std::stringstream message;
1494 : message << "'s_out' is not bound the this task ("
1495 132 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1496 66 : << "'s_out.name' = " << s_out.get_name() << ", "
1497 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1498 66 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1499 66 : << "'task.name' = " << this->get_name() << ", "
1500 132 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1501 66 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1502 66 : }
1503 : else
1504 : {
1505 50838 : std::stringstream message;
1506 : message << "This task does not have fake input socket to unbind ("
1507 101676 : << "'s_out.datatype' = " << type_to_string[s_out.datatype] << ", "
1508 50838 : << "'s_out.name' = " << s_out.get_name() << ", "
1509 0 : << "'s_out.task.name' = " << s_out.task.get_name() << ", "
1510 50838 : << "'s_out.task.module.name' = " << s_out.task.get_module().get_custom_name() << ", "
1511 50838 : << "'task.name' = " << this->get_name() << ", "
1512 101676 : << "'task.module.name' = " << this->get_module().get_custom_name() << ").";
1513 50838 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1514 50838 : }
1515 : }
1516 :
1517 : size_t
1518 278 : Task::unbind(Task& t_out)
1519 : {
1520 278 : return this->unbind(*t_out.sockets.back());
1521 : }
1522 :
1523 : size_t
1524 3125 : Task::get_n_static_input_sockets() const
1525 : {
1526 3125 : size_t n = 0;
1527 11579 : for (auto& s : this->sockets)
1528 8454 : if (s->get_type() == socket_t::SIN && s->_get_dataptr() != nullptr && s->bound_socket == nullptr) n++;
1529 3125 : return n;
1530 : }
1531 :
1532 : bool
1533 0 : Task::is_stateless() const
1534 : {
1535 0 : return this->get_module().is_stateless();
1536 : }
1537 :
1538 : bool
1539 0 : Task::is_stateful() const
1540 : {
1541 0 : return this->get_module().is_stateful();
1542 : }
1543 :
1544 : bool
1545 2490 : Task::is_replicable() const
1546 : {
1547 2490 : return this->replicable;
1548 : }
1549 :
1550 : void
1551 75 : Task::set_replicability(const bool replicable)
1552 : {
1553 75 : if (replicable && !this->module->is_clonable())
1554 : {
1555 0 : std::stringstream message;
1556 : message << "The replicability of this task cannot be set to true because its corresponding module is not "
1557 0 : << "clonable (task.name = '" << this->get_name() << "', module.name = '"
1558 0 : << this->get_module().get_name() << "').";
1559 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
1560 0 : }
1561 75 : this->replicable = replicable;
1562 : // this->replicable = replicable ? this->module->is_clonable() : false;
1563 75 : }
1564 :
1565 : void
1566 151391 : Task::set_outbuffers_allocated(const bool outbuffers_allocated)
1567 : {
1568 151391 : this->outbuffers_allocated = outbuffers_allocated;
1569 151391 : }
1570 :
1571 : // ==================================================================================== explicit template instantiation
1572 : template size_t
1573 : Task::create_2d_socket_in<int8_t>(const std::string&, const size_t, const size_t);
1574 : template size_t
1575 : Task::create_2d_socket_in<uint8_t>(const std::string&, const size_t, const size_t);
1576 : template size_t
1577 : Task::create_2d_socket_in<int16_t>(const std::string&, const size_t, const size_t);
1578 : template size_t
1579 : Task::create_2d_socket_in<uint16_t>(const std::string&, const size_t, const size_t);
1580 : template size_t
1581 : Task::create_2d_socket_in<int32_t>(const std::string&, const size_t, const size_t);
1582 : template size_t
1583 : Task::create_2d_socket_in<uint32_t>(const std::string&, const size_t, const size_t);
1584 : template size_t
1585 : Task::create_2d_socket_in<int64_t>(const std::string&, const size_t, const size_t);
1586 : template size_t
1587 : Task::create_2d_socket_in<uint64_t>(const std::string&, const size_t, const size_t);
1588 : template size_t
1589 : Task::create_2d_socket_in<float>(const std::string&, const size_t, const size_t);
1590 : template size_t
1591 : Task::create_2d_socket_in<double>(const std::string&, const size_t, const size_t);
1592 :
1593 : template size_t
1594 : Task::create_2d_socket_out<int8_t>(const std::string&, const size_t, const size_t, const bool);
1595 : template size_t
1596 : Task::create_2d_socket_out<uint8_t>(const std::string&, const size_t, const size_t, const bool);
1597 : template size_t
1598 : Task::create_2d_socket_out<int16_t>(const std::string&, const size_t, const size_t, const bool);
1599 : template size_t
1600 : Task::create_2d_socket_out<uint16_t>(const std::string&, const size_t, const size_t, const bool);
1601 : template size_t
1602 : Task::create_2d_socket_out<int32_t>(const std::string&, const size_t, const size_t, const bool);
1603 : template size_t
1604 : Task::create_2d_socket_out<uint32_t>(const std::string&, const size_t, const size_t, const bool);
1605 : template size_t
1606 : Task::create_2d_socket_out<int64_t>(const std::string&, const size_t, const size_t, const bool);
1607 : template size_t
1608 : Task::create_2d_socket_out<uint64_t>(const std::string&, const size_t, const size_t, const bool);
1609 : template size_t
1610 : Task::create_2d_socket_out<float>(const std::string&, const size_t, const size_t, const bool);
1611 : template size_t
1612 : Task::create_2d_socket_out<double>(const std::string&, const size_t, const size_t, const bool);
1613 :
1614 : template size_t
1615 : Task::create_2d_socket_fwd<int8_t>(const std::string&, const size_t, const size_t);
1616 : template size_t
1617 : Task::create_2d_socket_fwd<uint8_t>(const std::string&, const size_t, const size_t);
1618 : template size_t
1619 : Task::create_2d_socket_fwd<int16_t>(const std::string&, const size_t, const size_t);
1620 : template size_t
1621 : Task::create_2d_socket_fwd<uint16_t>(const std::string&, const size_t, const size_t);
1622 : template size_t
1623 : Task::create_2d_socket_fwd<int32_t>(const std::string&, const size_t, const size_t);
1624 : template size_t
1625 : Task::create_2d_socket_fwd<uint32_t>(const std::string&, const size_t, const size_t);
1626 : template size_t
1627 : Task::create_2d_socket_fwd<int64_t>(const std::string&, const size_t, const size_t);
1628 : template size_t
1629 : Task::create_2d_socket_fwd<uint64_t>(const std::string&, const size_t, const size_t);
1630 : template size_t
1631 : Task::create_2d_socket_fwd<float>(const std::string&, const size_t, const size_t);
1632 : template size_t
1633 : Task::create_2d_socket_fwd<double>(const std::string&, const size_t, const size_t);
1634 : // ==================================================================================== explicit template instantiation
|