Line data Source code
1 : #include <algorithm>
2 : #include <sstream>
3 : #include <string>
4 :
5 : #include "Module/Stateful/Switcher/Switcher.hpp"
6 : #include "Tools/Exception/exception.hpp"
7 : #include "Tools/compute_bytes.h"
8 :
9 : namespace spu
10 : {
11 : namespace module
12 : {
13 :
14 : runtime::Task&
15 : Switcher::operator[](const swi::tsk t)
16 : {
17 : return Module::operator[]((size_t)t);
18 : }
19 :
20 : const runtime::Task&
21 232 : Switcher::operator[](const swi::tsk t) const
22 : {
23 232 : return Module::operator[]((size_t)t);
24 : }
25 :
26 : runtime::Socket&
27 : Switcher::operator[](const std::string& tsk_sck)
28 : {
29 : return Module::operator[](tsk_sck);
30 : }
31 :
32 : Switcher::Switcher(const size_t n_data_sockets,
33 : const size_t n_elmts_commute,
34 : const std::type_index datatype_commute,
35 : const size_t n_elmts_select,
36 : const std::type_index datatype_select)
37 : : Stateful()
38 : , n_data_sockets(n_data_sockets)
39 : , n_elmts_commute(n_elmts_commute)
40 : , n_elmts_select(n_elmts_select)
41 : , n_bytes_commute(tools::compute_bytes(n_elmts_commute, datatype_commute))
42 : , n_bytes_select(tools::compute_bytes(n_elmts_select, datatype_select))
43 : , datatype_commute(datatype_commute)
44 : , datatype_select(datatype_select)
45 : , path(n_data_sockets - 1)
46 : , no_copy_commute(false)
47 : , no_copy_select(false)
48 : {
49 : const std::string name = "Switcher";
50 : this->set_name(name);
51 : this->set_short_name(name);
52 : this->set_single_wave(true);
53 :
54 : if (n_data_sockets == 0)
55 : {
56 : std::stringstream message;
57 : message << "'n_data_sockets' has to be greater than 0 ('n_data_sockets' = " << n_data_sockets << ").";
58 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
59 : }
60 :
61 : if (n_elmts_commute == 0)
62 : {
63 : std::stringstream message;
64 : message << "'n_elmts_commute' has to be greater than 0 ('n_elmts_commute' = " << n_elmts_commute << ").";
65 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
66 : }
67 :
68 : if (n_elmts_select == 0)
69 : {
70 : std::stringstream message;
71 : message << "'n_elmts_select' has to be greater than 0 ('n_elmts_select' = " << n_elmts_select << ").";
72 : throw tools::invalid_argument(__FILE__, __LINE__, __func__, message.str());
73 : }
74 :
75 : auto& p1 = this->create_task("commute");
76 : const auto p1s_in_data = this->create_socket_in(p1, "in_data", n_elmts_commute, datatype_commute);
77 : const auto p1s_in_ctrl = this->create_socket_in(p1, "in_ctrl", 1, typeid(int8_t));
78 :
79 : std::vector<size_t> p1s_out_data;
80 : for (size_t s = 0; s < this->get_n_data_sockets(); s++)
81 : p1s_out_data.push_back(
82 : this->create_socket_out(p1, "out_data" + std::to_string(s), n_elmts_commute, datatype_commute));
83 :
84 : this->create_codelet(
85 : p1,
86 : [p1s_in_data, p1s_in_ctrl, p1s_out_data](Module& m, runtime::Task& t, const size_t frame_id) -> int
87 : {
88 : auto& swi = static_cast<Switcher&>(m);
89 :
90 : const auto ctrl_socket_in = t[p1s_in_ctrl].get_dataptr<const int8_t>();
91 : swi.set_path((size_t)ctrl_socket_in[0]);
92 : const size_t path = swi.get_path();
93 :
94 : if (!swi.is_no_copy_commute())
95 : {
96 : const auto data_socket_in = t[p1s_in_data].get_dataptr<const int8_t>();
97 : auto data_socket_out = t[p1s_out_data[path]].get_dataptr<int8_t>();
98 :
99 : std::copy(
100 : data_socket_in, data_socket_in + swi.get_n_frames() * swi.get_n_bytes_commute(), data_socket_out);
101 : }
102 :
103 : return (int)path;
104 : });
105 :
106 : auto& p2 = this->create_task("select");
107 : std::vector<size_t> p2s_in_data;
108 : for (size_t s = 0; s < this->get_n_data_sockets(); s++)
109 : p2s_in_data.push_back(
110 : this->create_socket_in(p2, "in_data" + std::to_string(s), n_elmts_select, datatype_select));
111 : auto p2s_out_data = this->create_socket_out(p2, "out_data", n_elmts_select, datatype_select);
112 :
113 : this->create_codelet(p2,
114 : [p2s_in_data, p2s_out_data](Module& m, runtime::Task& t, const size_t frame_id) -> int
115 : {
116 : auto& swi = static_cast<Switcher&>(m);
117 :
118 : if (!swi.is_no_copy_select())
119 : {
120 : const size_t path = swi.get_path();
121 :
122 : const auto data_socket_in = t[p2s_in_data[path]].get_dataptr<const int8_t>();
123 : auto data_socket_out = t[p2s_out_data].get_dataptr<int8_t>();
124 :
125 : std::copy(data_socket_in,
126 : data_socket_in + swi.get_n_frames() * swi.get_n_bytes_select(),
127 : data_socket_out);
128 : }
129 :
130 : return runtime::status_t::SUCCESS;
131 : });
132 : }
133 :
134 : Switcher::Switcher(const size_t n_data_sockets, const size_t n_elmts, const std::type_index datatype)
135 : : Switcher(n_data_sockets, n_elmts, datatype, n_elmts, datatype)
136 : {
137 : }
138 :
139 : void
140 2870 : Switcher::set_no_copy_commute(const bool no_copy_commute)
141 : {
142 2870 : this->no_copy_commute = no_copy_commute;
143 2870 : }
144 :
145 : void
146 2870 : Switcher::set_no_copy_select(const bool no_copy_select)
147 : {
148 2870 : this->no_copy_select = no_copy_select;
149 2870 : }
150 :
151 : bool
152 : Switcher::is_no_copy_commute() const
153 : {
154 : return this->no_copy_commute;
155 : }
156 :
157 : bool
158 : Switcher::is_no_copy_select() const
159 : {
160 : return this->no_copy_select;
161 : }
162 :
163 : size_t
164 0 : Switcher::get_n_data_sockets() const
165 : {
166 0 : return this->n_data_sockets;
167 : }
168 :
169 : size_t
170 : Switcher::get_n_elmts_commute() const
171 : {
172 : return this->n_elmts_commute;
173 : }
174 :
175 : size_t
176 : Switcher::get_n_elmts_select() const
177 : {
178 : return this->n_elmts_select;
179 : }
180 :
181 : size_t
182 : Switcher::get_n_bytes_commute() const
183 : {
184 : return this->n_bytes_commute;
185 : }
186 :
187 : size_t
188 : Switcher::get_n_bytes_select() const
189 : {
190 : return this->n_bytes_select;
191 : }
192 :
193 : std::type_index
194 : Switcher::get_datatype_commute() const
195 : {
196 : return this->datatype_commute;
197 : }
198 :
199 : std::type_index
200 : Switcher::get_datatype_select() const
201 : {
202 : return this->datatype_select;
203 : }
204 :
205 : size_t
206 0 : Switcher::get_path() const
207 : {
208 0 : return this->path;
209 : }
210 :
211 : void
212 : Switcher::set_path(const size_t path)
213 : {
214 : this->path = path % this->get_n_data_sockets();
215 : }
216 :
217 : }
218 : }
|