Line data Source code
1 : #include <cmath>
2 : #include <sstream>
3 : #include <string>
4 :
5 : #include "Module/Stateful/Adaptor/Adaptor_m_to_n.hpp"
6 : #include "Tools/Exception/exception.hpp"
7 : #include "Tools/compute_bytes.h"
8 :
9 : namespace spu
10 : {
11 : namespace module
12 : {
13 :
14 269 : Adaptor_m_to_n::Adaptor_m_to_n(const std::vector<size_t>& n_elmts,
15 : const std::vector<std::type_index>& datatype,
16 : const size_t buffer_size,
17 269 : const bool active_waiting)
18 : : Stateful()
19 269 : , n_elmts(n_elmts)
20 269 : , n_bytes(tools::compute_bytes(n_elmts, datatype))
21 269 : , datatype(datatype)
22 269 : , buffer_size(buffer_size)
23 269 : , n_sockets(n_elmts.size())
24 269 : , buffer(new std::vector<std::vector<std::vector<int8_t*>>>(
25 : 1,
26 538 : std::vector<std::vector<int8_t*>>(n_sockets, std::vector<int8_t*>(buffer_size))))
27 269 : , first(new std::vector<uint32_t>(1000))
28 269 : , last(new std::vector<uint32_t>(1000))
29 269 : , counter(new std::vector<std::atomic<uint32_t>>(1000))
30 269 : , waiting_canceled(new std::atomic<bool>(false))
31 269 : , no_copy_pull(false)
32 269 : , no_copy_push(false)
33 269 : , active_waiting(active_waiting)
34 269 : , cnd_push(new std::vector<std::condition_variable>(1000))
35 269 : , mtx_push(new std::vector<std::mutex>(1000))
36 269 : , cnd_pull(new std::vector<std::condition_variable>(1000))
37 269 : , mtx_pull(new std::vector<std::mutex>(1000))
38 269 : , tid_push(0)
39 269 : , tid_pull(0)
40 269 : , n_pushers(new size_t(1))
41 269 : , n_pullers(new size_t(1))
42 269 : , n_clones(new size_t(0))
43 269 : , buffers_allocated(new bool(false))
44 269 : , cloned(false)
45 269 : , cur_push_id(0)
46 538 : , cur_pull_id(0)
47 : {
48 269 : const std::string name = "Adaptor_m_to_n";
49 269 : this->set_name(name);
50 269 : this->set_short_name(name);
51 :
52 269 : if (buffer_size == 0)
53 : {
54 0 : std::stringstream message;
55 0 : message << "'buffer_size' has to be greater than 0 ('buffer_size' = " << buffer_size << ").";
56 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
57 0 : }
58 :
59 269 : if (n_elmts.size() == 0)
60 : {
61 0 : std::stringstream message;
62 0 : message << "'n_elmts.size()' has to be greater than 0 ('n_elmts.size()' = " << n_elmts.size() << ").";
63 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
64 0 : }
65 :
66 727 : for (size_t e = 0; e < n_elmts.size(); e++)
67 : {
68 458 : if (n_elmts[e] == 0)
69 : {
70 0 : std::stringstream message;
71 0 : message << "'n_elmts[e]' has to be greater than 0 ('e' = " << e << ", 'n_elmts[e]' = " << n_elmts[e]
72 0 : << ").";
73 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
74 0 : }
75 : }
76 :
77 269 : this->init();
78 269 : }
79 :
80 : Adaptor_m_to_n::Adaptor_m_to_n(const size_t n_elmts,
81 : const std::type_index datatype,
82 : const size_t buffer_size,
83 : const bool active_waiting)
84 : : Adaptor_m_to_n(std::vector<size_t>(1, n_elmts),
85 : std::vector<std::type_index>(1, datatype),
86 : buffer_size,
87 : active_waiting)
88 : {
89 : }
90 :
91 : void
92 269 : Adaptor_m_to_n::init()
93 : {
94 269 : const std::string name = "Adaptor_m_to_n";
95 269 : this->set_name(name);
96 269 : this->set_short_name(name);
97 269 : this->set_single_wave(true);
98 :
99 269 : auto& p1 = this->create_task("push");
100 269 : std::vector<size_t> p1s_in;
101 727 : for (size_t s = 0; s < this->n_sockets; s++)
102 458 : p1s_in.push_back(this->create_socket_in(p1, "in" + std::to_string(s), this->n_elmts[s], this->datatype[s]));
103 :
104 269 : this->create_codelet(p1,
105 587810 : [p1s_in](Module& m, runtime::Task& t, const size_t frame_id) -> int
106 : {
107 587810 : auto& adp = static_cast<Adaptor_m_to_n&>(m);
108 587810 : if (adp.is_no_copy_push())
109 : {
110 590833 : adp.wait_push();
111 : // for debug mode coherence
112 1644745 : for (size_t s = 0; s < t.sockets.size() - 1; s++)
113 988994 : t.sockets[s]->dataptr = adp.get_empty_buffer(s);
114 : }
115 : else
116 : {
117 0 : std::vector<const int8_t*> sockets_dataptr(p1s_in.size());
118 0 : for (size_t s = 0; s < p1s_in.size(); s++)
119 0 : sockets_dataptr[s] = t[p1s_in[s]].get_dataptr<const int8_t>();
120 0 : adp.push(sockets_dataptr, frame_id);
121 0 : }
122 616816 : return runtime::status_t::SUCCESS;
123 : });
124 :
125 269 : auto& p2 = this->create_task("pull");
126 269 : std::vector<size_t> p2s_out;
127 727 : for (size_t s = 0; s < this->n_sockets; s++)
128 458 : p2s_out.push_back(this->create_socket_out(p2, "out" + std::to_string(s), this->n_elmts[s], this->datatype[s]));
129 :
130 269 : this->create_codelet(p2,
131 540175 : [p2s_out](Module& m, runtime::Task& t, const size_t frame_id) -> int
132 : {
133 540175 : auto& adp = static_cast<Adaptor_m_to_n&>(m);
134 540175 : if (adp.is_no_copy_pull())
135 : {
136 540735 : adp.wait_pull();
137 : // for debug mode coherence
138 1450145 : for (size_t s = 0; s < t.sockets.size() - 1; s++)
139 848809 : t.sockets[s]->_bind(adp.get_filled_buffer(s));
140 : }
141 : else
142 : {
143 0 : std::vector<int8_t*> sockets_dataptr(p2s_out.size());
144 0 : for (size_t s = 0; s < p2s_out.size(); s++)
145 0 : sockets_dataptr[s] = t[p2s_out[s]].get_dataptr<int8_t>();
146 0 : adp.pull(sockets_dataptr, frame_id);
147 0 : }
148 549930 : return runtime::status_t::SUCCESS;
149 : });
150 269 : }
151 :
152 : size_t
153 : Adaptor_m_to_n::get_n_elmts(const size_t sid) const
154 : {
155 : return this->n_elmts[sid];
156 : }
157 :
158 : size_t
159 : Adaptor_m_to_n::get_n_bytes(const size_t sid) const
160 : {
161 : return this->n_bytes[sid];
162 : }
163 :
164 : std::type_index
165 : Adaptor_m_to_n::get_datatype(const size_t sid) const
166 : {
167 : return this->datatype[sid];
168 : }
169 :
170 : bool
171 2522450 : Adaptor_m_to_n::is_empty(const size_t id)
172 : {
173 2522450 : return (*this->counter)[id] == this->buffer_size;
174 : }
175 :
176 : bool
177 1614038 : Adaptor_m_to_n::is_full(const size_t id)
178 : {
179 1614038 : return (*this->counter)[id] == 0;
180 : }
181 :
182 : size_t
183 : Adaptor_m_to_n::n_fill_slots(const size_t id)
184 : {
185 : return this->buffer_size - (*this->counter)[id];
186 : }
187 :
188 : size_t
189 : Adaptor_m_to_n::n_free_slots(const size_t id)
190 : {
191 : return (*this->counter)[id];
192 : }
193 :
194 : }
195 : }
|