Line data Source code
1 : #include <sstream>
2 : #include <string>
3 :
4 : #include "Module/Stateful/Set/Set.hpp"
5 : #include "Runtime/Sequence/Sequence.hpp"
6 : #include "Tools/Exception/exception.hpp"
7 :
8 : using namespace spu;
9 : using namespace spu::module;
10 :
11 1 : Set::Set(runtime::Sequence& sequence)
12 : : Stateful()
13 1 : , sequence_extern(&sequence)
14 : {
15 1 : this->init();
16 1 : }
17 :
18 0 : Set::Set(const runtime::Sequence& sequence)
19 : : Stateful()
20 0 : , sequence_cloned(sequence.clone())
21 0 : , sequence_extern(nullptr)
22 : {
23 0 : this->init();
24 0 : }
25 :
26 : void
27 1 : Set::init()
28 : {
29 1 : const std::string name = "Set";
30 1 : this->set_name(name);
31 1 : this->set_short_name(name);
32 1 : this->set_single_wave(true);
33 :
34 1 : auto& sequence = this->get_sequence();
35 :
36 1 : if (sequence.get_n_threads() != 1)
37 : {
38 0 : std::stringstream message;
39 0 : message << "'sequence.get_n_threads()' has to be equal to 1 ('sequence.get_n_threads()' = "
40 0 : << sequence.get_n_threads() << ").";
41 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
42 0 : }
43 :
44 1 : auto& p = this->create_task("exec");
45 :
46 1 : auto& firsts = sequence.get_firsts_tasks()[0];
47 2 : for (auto& first : firsts)
48 4 : for (auto& s : first->sockets)
49 : {
50 3 : if (s->get_type() == runtime::socket_t::SIN)
51 : {
52 1 : if (s->get_datatype() == typeid(int8_t))
53 0 : this->template create_socket_in<int8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
54 1 : else if (s->get_datatype() == typeid(uint8_t))
55 1 : this->template create_socket_in<uint8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
56 0 : else if (s->get_datatype() == typeid(int16_t))
57 0 : this->template create_socket_in<int16_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
58 0 : else if (s->get_datatype() == typeid(uint16_t))
59 0 : this->template create_socket_in<uint16_t>(
60 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
61 0 : else if (s->get_datatype() == typeid(int32_t))
62 0 : this->template create_socket_in<int32_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
63 0 : else if (s->get_datatype() == typeid(uint32_t))
64 0 : this->template create_socket_in<uint32_t>(
65 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
66 0 : else if (s->get_datatype() == typeid(int64_t))
67 0 : this->template create_socket_in<int64_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
68 0 : else if (s->get_datatype() == typeid(uint64_t))
69 0 : this->template create_socket_in<uint64_t>(
70 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
71 0 : else if (s->get_datatype() == typeid(float))
72 0 : this->template create_socket_in<float>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
73 0 : else if (s->get_datatype() == typeid(double))
74 0 : this->template create_socket_in<double>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
75 : }
76 2 : else if (s->get_type() == runtime::socket_t::SFWD)
77 : {
78 0 : std::stringstream message;
79 0 : message << "Forward socket is not supported yet :-(.";
80 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
81 0 : }
82 : }
83 1 : auto& lasts = sequence.get_lasts_tasks()[0];
84 2 : for (auto& last : lasts)
85 4 : for (auto& s : last->sockets)
86 : {
87 3 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
88 : {
89 1 : if (s->get_datatype() == typeid(int8_t))
90 0 : this->template create_socket_out<int8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
91 1 : else if (s->get_datatype() == typeid(uint8_t))
92 1 : this->template create_socket_out<uint8_t>(
93 1 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
94 0 : else if (s->get_datatype() == typeid(int16_t))
95 0 : this->template create_socket_out<int16_t>(
96 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
97 0 : else if (s->get_datatype() == typeid(uint16_t))
98 0 : this->template create_socket_out<uint16_t>(
99 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
100 0 : else if (s->get_datatype() == typeid(int32_t))
101 0 : this->template create_socket_out<int32_t>(
102 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
103 0 : else if (s->get_datatype() == typeid(uint32_t))
104 0 : this->template create_socket_out<uint32_t>(
105 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
106 0 : else if (s->get_datatype() == typeid(int64_t))
107 0 : this->template create_socket_out<int64_t>(
108 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
109 0 : else if (s->get_datatype() == typeid(uint64_t))
110 0 : this->template create_socket_out<uint64_t>(
111 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
112 0 : else if (s->get_datatype() == typeid(float))
113 0 : this->template create_socket_out<float>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
114 0 : else if (s->get_datatype() == typeid(double))
115 0 : this->template create_socket_out<double>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
116 : }
117 2 : else if (s->get_type() == runtime::socket_t::SFWD)
118 : {
119 0 : std::stringstream message;
120 0 : message << "Forward socket is not supported yet :-(.";
121 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
122 0 : }
123 : }
124 :
125 1 : size_t sid = 0;
126 2 : for (auto& last : lasts)
127 4 : for (auto& s : last->sockets)
128 : {
129 3 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
130 : {
131 2 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
132 1 : sid++;
133 1 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
134 : }
135 : }
136 :
137 1 : this->create_codelet(p,
138 56 : [](Module& m, runtime::Task& t, const size_t /*frame_id*/) -> int
139 : {
140 56 : auto& ss = static_cast<Set&>(m);
141 :
142 56 : auto& firsts = ss.get_sequence().get_firsts_tasks()[0];
143 56 : size_t sid = 0;
144 86 : for (auto& first : firsts)
145 181 : for (auto& s : first->sockets)
146 : {
147 143 : if (s->get_type() == runtime::socket_t::SIN)
148 : {
149 50 : while (t.sockets[sid]->get_type() != runtime::socket_t::SIN)
150 0 : sid++;
151 44 : (*s) = t.sockets[sid++]->_get_dataptr();
152 : }
153 : }
154 :
155 : // execute all frames sequentially
156 40 : ss.get_sequence().exec_seq();
157 :
158 61 : return runtime::status_t::SUCCESS;
159 : });
160 1 : }
161 :
162 : runtime::Sequence&
163 294 : Set::get_sequence()
164 : {
165 294 : if (this->sequence_extern)
166 6 : return *this->sequence_extern;
167 : else
168 288 : return *this->sequence_cloned;
169 : }
170 :
171 : Set*
172 48 : Set::clone() const
173 : {
174 48 : auto m = new Set(*this);
175 48 : m->deep_copy(*this);
176 48 : return m;
177 : }
178 :
179 : void
180 48 : Set::deep_copy(const Set& m)
181 : {
182 48 : Stateful::deep_copy(m);
183 48 : if (m.sequence_cloned != nullptr)
184 0 : this->sequence_cloned.reset(m.sequence_cloned->clone());
185 : else
186 : {
187 48 : this->sequence_cloned.reset(m.sequence_extern->clone());
188 48 : this->sequence_extern = nullptr;
189 : }
190 :
191 48 : auto& lasts = this->get_sequence().get_lasts_tasks()[0];
192 :
193 : try
194 : {
195 50 : auto& p = (*this)("exec");
196 :
197 47 : size_t sid = 0;
198 94 : for (auto& last : lasts)
199 188 : for (auto& s : last->sockets)
200 : {
201 141 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
202 : {
203 94 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
204 47 : sid++;
205 47 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
206 : }
207 : }
208 : }
209 1 : catch (tools::invalid_argument&)
210 : {
211 : /* this is a hack: do nothing, we went there because of trying to determine if the set is replicable */
212 1 : }
213 48 : }
214 :
215 : void
216 97 : Set::set_n_frames(const size_t n_frames)
217 : {
218 97 : const auto old_n_frames = this->get_n_frames();
219 97 : if (old_n_frames != n_frames)
220 : {
221 96 : auto& p = *this->tasks[0];
222 96 : auto& lasts = this->get_sequence().get_lasts_tasks()[0];
223 96 : size_t sid = 0;
224 192 : for (auto& last : lasts)
225 384 : for (auto& s : last->sockets)
226 : {
227 288 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
228 : {
229 192 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
230 96 : sid++;
231 96 : p.sockets[sid++]->unbind(*s);
232 : }
233 : }
234 :
235 96 : Module::set_n_frames(n_frames);
236 :
237 96 : if (this->sequence_extern)
238 2 : this->sequence_extern->set_n_frames(n_frames);
239 : else
240 94 : this->sequence_cloned->set_n_frames(n_frames);
241 :
242 96 : sid = 0;
243 192 : for (auto& last : lasts)
244 384 : for (auto& s : last->sockets)
245 : {
246 288 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
247 : {
248 192 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
249 96 : sid++;
250 96 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
251 : }
252 : }
253 : }
254 97 : }
|