16#ifndef CONVOLUTIONS_HPP
17#define CONVOLUTIONS_HPP
72 constexpr type
promote(type signal, type kernel)
const {
74 throw Exception(
"Dtypes must match in convolutions.");
97 if (signal.
rank() != 3 || kernel.
rank() != 3)
98 throw Exception(
"Only rank 3 tensors supported for Conv1D.");
100 std::size_t N = signal[0];
101 std::size_t C_in = signal[1];
102 std::size_t L_in = signal[2];
104 std::size_t C_out = kernel[0];
105 std::size_t K_in = kernel[1];
106 std::size_t K_len = kernel[2];
109 throw Exception(
"Input channels must match kernel channels.");
111 std::size_t L_out = (L_in + 2 *
padding[0] - K_len) /
strides[0] + 1;
112 return Shape{N, C_out, L_out};
148 constexpr type
promote(type signal, type kernel)
const {
149 if (signal != kernel)
150 throw Exception(
"Dtypes must match in convolutions.");
// Reject operands whose rank is not 4. Uses logical `||` (short-circuiting)
// rather than bitwise `|`, matching the rank-3 check used for Conv1D.
175 if (signal.rank() != 4 || kernel.rank() != 4)
176 throw Exception(
"Only rank 4 tensors supported.");
178 std::size_t N = signal[0];
179 std::size_t C_in = signal[1];
180 std::size_t H_in = signal[2];
181 std::size_t W_in = signal[3];
183 std::size_t C_out = kernel[0];
184 std::size_t K_in = kernel[1];
185 std::size_t K_h = kernel[2];
186 std::size_t K_w = kernel[3];
189 throw Exception(
"Input channels must match kernel channels.");
191 std::size_t H_out = (H_in + 2 *
padding[0] - K_h) /
strides[0] + 1;
192 std::size_t W_out = (W_in + 2 *
padding[1] - K_w) /
strides[1] + 1;
194 return Shape{N, C_out, H_out, W_out};
219template<Expression Signal, Expression Kernel>
220constexpr auto convolve1D(Signal&& signal, Kernel&& kernel, std::size_t stride, std::size_t padding) {
222 throw Exception(
"Stride must be non-zero for Conv1D.");
225 {{stride}, {padding}},
226 std::forward<Signal>(signal),
227 std::forward<Kernel>(kernel)
243template<Expression Signal, Expression Kernel>
245 Signal&& signal, Kernel&& kernel,
246 std::array<std::size_t, 1> strides,
247 std::array<std::size_t, 1> padding
250 throw Exception(
"Stride must be non-zero for Conv1D.");
254 std::forward<Signal>(signal),
255 std::forward<Kernel>(kernel)
271template<Expression Signal, Expression Kernel>
272constexpr auto convolve2D(Signal&& signal, Kernel&& kernel, std::size_t stride, std::size_t padding) {
274 throw Exception(
"Stride must be non-zero for Conv2D.");
277 {{stride, stride}, {padding, padding}},
278 std::forward<Signal>(signal),
279 std::forward<Kernel>(kernel)
295template<Expression Signal, Expression Kernel>
297 Signal&& signal, Kernel&& kernel,
298 std::array<std::size_t, 2> strides,
299 std::array<std::size_t, 2> padding
301 if (strides[0] == 0 || strides[1] == 0)
302 throw Exception(
"Stride values must be non-zero for Conv2D.");
306 std::forward<Signal>(signal),
307 std::forward<Kernel>(kernel)
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Definition: buffer.hpp:41