Line data Source code
1 : #include <cmath>
2 : #include <sstream>
3 : #include <string>
4 :
5 : #include "Module/Stateful/Adaptor/Adaptor_m_to_n.hpp"
6 : #include "Tools/Exception/exception.hpp"
7 : #include "Tools/compute_bytes.h"
8 :
9 : namespace spu
10 : {
11 : namespace module
12 : {
13 :
14 242 : Adaptor_m_to_n::Adaptor_m_to_n(const std::vector<size_t>& n_elmts,
15 : const std::vector<std::type_index>& datatype,
16 : const size_t buffer_size,
17 242 : const bool active_waiting)
18 : : Stateful()
19 242 : , n_elmts(n_elmts)
20 242 : , n_bytes(tools::compute_bytes(n_elmts, datatype))
21 242 : , datatype(datatype)
22 242 : , buffer_size(buffer_size)
23 242 : , n_sockets(n_elmts.size())
24 242 : , buffer(new std::vector<std::vector<std::vector<int8_t*>>>(
25 : 1,
26 484 : std::vector<std::vector<int8_t*>>(n_sockets, std::vector<int8_t*>(buffer_size))))
27 242 : , first(new std::vector<uint32_t>(1000))
28 242 : , last(new std::vector<uint32_t>(1000))
29 242 : , counter(new std::vector<std::atomic<uint32_t>>(1000))
30 242 : , waiting_canceled(new std::atomic<bool>(false))
31 242 : , no_copy_pull(false)
32 242 : , no_copy_push(false)
33 242 : , active_waiting(active_waiting)
34 242 : , cnd_push(new std::vector<std::condition_variable>(1000))
35 242 : , mtx_push(new std::vector<std::mutex>(1000))
36 242 : , cnd_pull(new std::vector<std::condition_variable>(1000))
37 242 : , mtx_pull(new std::vector<std::mutex>(1000))
38 242 : , tid_push(0)
39 242 : , tid_pull(0)
40 242 : , n_pushers(new size_t(1))
41 242 : , n_pullers(new size_t(1))
42 242 : , n_clones(new size_t(0))
43 242 : , buffers_allocated(new bool(false))
44 242 : , cloned(false)
45 242 : , cur_push_id(0)
46 484 : , cur_pull_id(0)
47 : {
48 242 : const std::string name = "Adaptor_m_to_n";
49 242 : this->set_name(name);
50 242 : this->set_short_name(name);
51 :
52 242 : if (buffer_size == 0)
53 : {
54 0 : std::stringstream message;
55 0 : message << "'buffer_size' has to be greater than 0 ('buffer_size' = " << buffer_size << ").";
56 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
57 0 : }
58 :
59 242 : if (n_elmts.size() == 0)
60 : {
61 0 : std::stringstream message;
62 0 : message << "'n_elmts.size()' has to be greater than 0 ('n_elmts.size()' = " << n_elmts.size() << ").";
63 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
64 0 : }
65 :
66 652 : for (size_t e = 0; e < n_elmts.size(); e++)
67 : {
68 410 : if (n_elmts[e] == 0)
69 : {
70 0 : std::stringstream message;
71 0 : message << "'n_elmts[e]' has to be greater than 0 ('e' = " << e << ", 'n_elmts[e]' = " << n_elmts[e]
72 0 : << ").";
73 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
74 0 : }
75 : }
76 :
77 242 : this->init();
78 242 : }
79 :
80 : Adaptor_m_to_n::Adaptor_m_to_n(const size_t n_elmts,
81 : const std::type_index datatype,
82 : const size_t buffer_size,
83 : const bool active_waiting)
84 : : Adaptor_m_to_n(std::vector<size_t>(1, n_elmts),
85 : std::vector<std::type_index>(1, datatype),
86 : buffer_size,
87 : active_waiting)
88 : {
89 : }
90 :
91 : void
92 242 : Adaptor_m_to_n::init()
93 : {
94 242 : const std::string name = "Adaptor_m_to_n";
95 242 : this->set_name(name);
96 242 : this->set_short_name(name);
97 242 : this->set_single_wave(true);
98 :
99 242 : auto& p1 = this->create_task("push");
100 242 : std::vector<size_t> p1s_in;
101 652 : for (size_t s = 0; s < this->n_sockets; s++)
102 410 : p1s_in.push_back(this->create_socket_in(p1, "in" + std::to_string(s), this->n_elmts[s], this->datatype[s]));
103 :
104 242 : this->create_codelet(p1,
105 543711 : [p1s_in](Module& m, runtime::Task& t, const size_t frame_id) -> int
106 : {
107 543711 : auto& adp = static_cast<Adaptor_m_to_n&>(m);
108 543711 : if (adp.is_no_copy_push())
109 : {
110 545065 : adp.wait_push();
111 : // for debug mode coherence
112 1357446 : for (size_t s = 0; s < t.sockets.size() - 1; s++)
113 771971 : t.sockets[s]->dataptr = adp.get_empty_buffer(s);
114 : }
115 : else
116 : {
117 0 : std::vector<const int8_t*> sockets_dataptr(p1s_in.size());
118 0 : for (size_t s = 0; s < p1s_in.size(); s++)
119 0 : sockets_dataptr[s] = t[p1s_in[s]].get_dataptr<const int8_t>();
120 0 : adp.push(sockets_dataptr, frame_id);
121 0 : }
122 565157 : return runtime::status_t::SUCCESS;
123 : });
124 :
125 242 : auto& p2 = this->create_task("pull");
126 242 : std::vector<size_t> p2s_out;
127 652 : for (size_t s = 0; s < this->n_sockets; s++)
128 410 : p2s_out.push_back(this->create_socket_out(p2, "out" + std::to_string(s), this->n_elmts[s], this->datatype[s]));
129 :
130 242 : this->create_codelet(p2,
131 535705 : [p2s_out](Module& m, runtime::Task& t, const size_t frame_id) -> int
132 : {
133 535705 : auto& adp = static_cast<Adaptor_m_to_n&>(m);
134 535705 : if (adp.is_no_copy_pull())
135 : {
136 536161 : adp.wait_pull();
137 : // for debug mode coherence
138 1269381 : for (size_t s = 0; s < t.sockets.size() - 1; s++)
139 709882 : t.sockets[s]->_bind(adp.get_filled_buffer(s));
140 : }
141 : else
142 : {
143 0 : std::vector<int8_t*> sockets_dataptr(p2s_out.size());
144 0 : for (size_t s = 0; s < p2s_out.size(); s++)
145 0 : sockets_dataptr[s] = t[p2s_out[s]].get_dataptr<int8_t>();
146 0 : adp.pull(sockets_dataptr, frame_id);
147 0 : }
148 532178 : return runtime::status_t::SUCCESS;
149 : });
150 242 : }
151 :
152 : size_t
153 : Adaptor_m_to_n::get_n_elmts(const size_t sid) const
154 : {
155 : return this->n_elmts[sid];
156 : }
157 :
158 : size_t
159 : Adaptor_m_to_n::get_n_bytes(const size_t sid) const
160 : {
161 : return this->n_bytes[sid];
162 : }
163 :
164 : std::type_index
165 : Adaptor_m_to_n::get_datatype(const size_t sid) const
166 : {
167 : return this->datatype[sid];
168 : }
169 :
170 : bool
171 2682412 : Adaptor_m_to_n::is_empty(const size_t id)
172 : {
173 2682412 : return (*this->counter)[id] == this->buffer_size;
174 : }
175 :
176 : bool
177 1717998 : Adaptor_m_to_n::is_full(const size_t id)
178 : {
179 1717998 : return (*this->counter)[id] == 0;
180 : }
181 :
182 : size_t
183 : Adaptor_m_to_n::n_fill_slots(const size_t id)
184 : {
185 : return this->buffer_size - (*this->counter)[id];
186 : }
187 :
188 : size_t
189 : Adaptor_m_to_n::n_free_slots(const size_t id)
190 : {
191 : return (*this->counter)[id];
192 : }
193 :
194 : }
195 : }
|