Line data Source code
1 : #include <sstream>
2 : #include <string>
3 :
4 : #include "Module/Stateful/Set/Set.hpp"
5 : #include "Runtime/Sequence/Sequence.hpp"
6 : #include "Tools/Exception/exception.hpp"
7 :
8 : using namespace spu;
9 : using namespace spu::module;
10 :
11 1 : Set::Set(runtime::Sequence& sequence)
12 : : Stateful()
13 1 : , sequence_extern(&sequence)
14 : {
15 1 : this->init();
16 1 : }
17 :
18 0 : Set::Set(const runtime::Sequence& sequence)
19 : : Stateful()
20 0 : , sequence_cloned(sequence.clone())
21 0 : , sequence_extern(nullptr)
22 : {
23 0 : this->init();
24 0 : }
25 :
26 : void
27 1 : Set::init()
28 : {
29 1 : const std::string name = "Set";
30 1 : this->set_name(name);
31 1 : this->set_short_name(name);
32 1 : this->set_single_wave(true);
33 :
34 1 : auto& sequence = this->get_sequence();
35 :
36 1 : if (sequence.get_n_threads() != 1)
37 : {
38 0 : std::stringstream message;
39 0 : message << "'sequence.get_n_threads()' has to be equal to 1 ('sequence.get_n_threads()' = "
40 0 : << sequence.get_n_threads() << ").";
41 0 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
42 0 : }
43 :
44 1 : auto& p = this->create_task("exec");
45 1 : p.set_autoalloc(true);
46 :
47 1 : auto& firsts = sequence.get_firsts_tasks()[0];
48 2 : for (auto& first : firsts)
49 4 : for (auto& s : first->sockets)
50 : {
51 3 : if (s->get_type() == runtime::socket_t::SIN)
52 : {
53 1 : if (s->get_datatype() == typeid(int8_t))
54 0 : this->template create_socket_in<int8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
55 1 : else if (s->get_datatype() == typeid(uint8_t))
56 1 : this->template create_socket_in<uint8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
57 0 : else if (s->get_datatype() == typeid(int16_t))
58 0 : this->template create_socket_in<int16_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
59 0 : else if (s->get_datatype() == typeid(uint16_t))
60 0 : this->template create_socket_in<uint16_t>(
61 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
62 0 : else if (s->get_datatype() == typeid(int32_t))
63 0 : this->template create_socket_in<int32_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
64 0 : else if (s->get_datatype() == typeid(uint32_t))
65 0 : this->template create_socket_in<uint32_t>(
66 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
67 0 : else if (s->get_datatype() == typeid(int64_t))
68 0 : this->template create_socket_in<int64_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
69 0 : else if (s->get_datatype() == typeid(uint64_t))
70 0 : this->template create_socket_in<uint64_t>(
71 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
72 0 : else if (s->get_datatype() == typeid(float))
73 0 : this->template create_socket_in<float>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
74 0 : else if (s->get_datatype() == typeid(double))
75 0 : this->template create_socket_in<double>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
76 : }
77 2 : else if (s->get_type() == runtime::socket_t::SFWD)
78 : {
79 0 : std::stringstream message;
80 0 : message << "Forward socket is not supported yet :-(.";
81 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
82 0 : }
83 : }
84 1 : auto& lasts = sequence.get_lasts_tasks()[0];
85 2 : for (auto& last : lasts)
86 4 : for (auto& s : last->sockets)
87 : {
88 3 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
89 : {
90 1 : if (s->get_datatype() == typeid(int8_t))
91 0 : this->template create_socket_out<int8_t>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
92 1 : else if (s->get_datatype() == typeid(uint8_t))
93 1 : this->template create_socket_out<uint8_t>(
94 1 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
95 0 : else if (s->get_datatype() == typeid(int16_t))
96 0 : this->template create_socket_out<int16_t>(
97 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
98 0 : else if (s->get_datatype() == typeid(uint16_t))
99 0 : this->template create_socket_out<uint16_t>(
100 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
101 0 : else if (s->get_datatype() == typeid(int32_t))
102 0 : this->template create_socket_out<int32_t>(
103 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
104 0 : else if (s->get_datatype() == typeid(uint32_t))
105 0 : this->template create_socket_out<uint32_t>(
106 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
107 0 : else if (s->get_datatype() == typeid(int64_t))
108 0 : this->template create_socket_out<int64_t>(
109 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
110 0 : else if (s->get_datatype() == typeid(uint64_t))
111 0 : this->template create_socket_out<uint64_t>(
112 0 : p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
113 0 : else if (s->get_datatype() == typeid(float))
114 0 : this->template create_socket_out<float>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
115 0 : else if (s->get_datatype() == typeid(double))
116 0 : this->template create_socket_out<double>(p, s->get_name(), s->get_n_elmts() / this->get_n_frames());
117 : }
118 2 : else if (s->get_type() == runtime::socket_t::SFWD)
119 : {
120 0 : std::stringstream message;
121 0 : message << "Forward socket is not supported yet :-(.";
122 0 : throw tools::runtime_error(__FILE__, __LINE__, __func__, message.str());
123 0 : }
124 : }
125 :
126 1 : size_t sid = 0;
127 2 : for (auto& last : lasts)
128 4 : for (auto& s : last->sockets)
129 : {
130 3 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
131 : {
132 2 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
133 1 : sid++;
134 1 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
135 : }
136 : }
137 :
138 1 : this->create_codelet(p,
139 48 : [](Module& m, runtime::Task& t, const size_t frame_id) -> int
140 : {
141 48 : auto& ss = static_cast<Set&>(m);
142 :
143 48 : auto& firsts = ss.get_sequence().get_firsts_tasks()[0];
144 50 : size_t sid = 0;
145 81 : for (auto& first : firsts)
146 182 : for (auto& s : first->sockets)
147 : {
148 128 : if (s->get_type() == runtime::socket_t::SIN)
149 : {
150 45 : while (t.sockets[sid]->get_type() != runtime::socket_t::SIN)
151 0 : sid++;
152 42 : (*s) = t.sockets[sid++]->_get_dataptr();
153 : }
154 : }
155 :
156 : // execute all frames sequentially
157 42 : ss.get_sequence().exec_seq();
158 :
159 62 : return runtime::status_t::SUCCESS;
160 : });
161 1 : }
162 :
163 : runtime::Sequence&
164 282 : Set::get_sequence()
165 : {
166 282 : if (this->sequence_extern)
167 6 : return *this->sequence_extern;
168 : else
169 276 : return *this->sequence_cloned;
170 : }
171 :
172 : Set*
173 48 : Set::clone() const
174 : {
175 48 : auto m = new Set(*this);
176 48 : m->deep_copy(*this);
177 48 : return m;
178 : }
179 :
180 : void
181 48 : Set::deep_copy(const Set& m)
182 : {
183 48 : Stateful::deep_copy(m);
184 48 : if (m.sequence_cloned != nullptr)
185 0 : this->sequence_cloned.reset(m.sequence_cloned->clone());
186 : else
187 : {
188 48 : this->sequence_cloned.reset(m.sequence_extern->clone());
189 48 : this->sequence_extern = nullptr;
190 : }
191 :
192 48 : auto& lasts = this->get_sequence().get_lasts_tasks()[0];
193 :
194 : try
195 : {
196 50 : auto& p = (*this)("exec");
197 :
198 47 : size_t sid = 0;
199 94 : for (auto& last : lasts)
200 188 : for (auto& s : last->sockets)
201 : {
202 141 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
203 : {
204 94 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
205 47 : sid++;
206 47 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
207 : }
208 : }
209 : }
210 1 : catch (tools::invalid_argument&)
211 : {
212 : /* this is a hack: do nothing, we went there because of trying to determine if the set is replicable */
213 1 : }
214 48 : }
215 :
216 : void
217 97 : Set::set_n_frames(const size_t n_frames)
218 : {
219 97 : const auto old_n_frames = this->get_n_frames();
220 97 : if (old_n_frames != n_frames)
221 : {
222 96 : auto& p = *this->tasks[0];
223 96 : auto& lasts = this->get_sequence().get_lasts_tasks()[0];
224 96 : size_t sid = 0;
225 192 : for (auto& last : lasts)
226 384 : for (auto& s : last->sockets)
227 : {
228 288 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
229 : {
230 192 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
231 96 : sid++;
232 96 : p.sockets[sid++]->unbind(*s);
233 : }
234 : }
235 :
236 96 : Module::set_n_frames(n_frames);
237 :
238 96 : if (this->sequence_extern)
239 2 : this->sequence_extern->set_n_frames(n_frames);
240 : else
241 94 : this->sequence_cloned->set_n_frames(n_frames);
242 :
243 96 : sid = 0;
244 192 : for (auto& last : lasts)
245 384 : for (auto& s : last->sockets)
246 : {
247 288 : if (s->get_type() == runtime::socket_t::SOUT && s->get_name() != "status")
248 : {
249 192 : while (p.sockets[sid]->get_type() != runtime::socket_t::SOUT)
250 96 : sid++;
251 96 : p.sockets[sid++]->_bind(*s); // out to out socket binding = black magic
252 : }
253 : }
254 : }
255 97 : }
|