74template<
class Reducer, Expression Operand>
106 reducer.forward(source, result);
141 constexpr type
reduce(type dtype)
const {
146 if (shape.
rank() == 0)
147 throw Exception(
"Cannot reduce scalar tensors");
150 for (
size_t dim = 0; dim < shape.
rank(); ++dim) {
151 if (dim !=
static_cast<size_t>(
axis)) out.
expand(shape[dim]);
179 constexpr type
reduce(type dtype)
const {
184 if (shape.
rank() == 0)
185 throw Exception(
"Cannot reduce scalar tensors");
188 for (uint8_t dimension = 0; dimension < shape.
rank(); ++dimension) {
189 if (dimension !=
static_cast<uint8_t
>(
axis))
190 reduced.
expand(shape[dimension]);
221 constexpr type
reduce(type dtype)
const {
227 for (
size_t dim = 0; dim < shape.
rank(); ++dim) {
228 if (dim !=
static_cast<size_t>(
axis)) {
229 reduced.
expand(shape[dim]);
260 constexpr type
reduce(type dtype)
const {
261 assert(dtype == float32 | dtype == float64 &&
"Integral dtypes not supported.");
267 for (
size_t dim = 0; dim < shape.
rank(); ++dim) {
268 if (dim !=
static_cast<size_t>(
axis)) {
269 reduced.
expand(shape[dim]);
301template<Expression Source>
302constexpr auto argmax(Source&& source,
int axis = -1,
bool keepdim =
false) {
323template<Expression Source>
324constexpr auto argmin(Source&& source,
int axis = -1,
bool keepdim =
false) {
350template<Expression Source>
351constexpr auto sum(Source&& source,
int axis = -1,
bool keepdim =
false) {
354 std::forward<Source>(source)
378template<Expression Source>
379constexpr auto mean(Source&& source,
int axis = -1,
bool keepdim =
false) {
382 std::forward<Source>(source)
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr void expand(size_type size)
Expands the shape's last dimension with a given size.
Definition: shape.hpp:286
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Lazy reduction expression.
Definition: reductions.hpp:75
constexpr type dtype() const
Definition: reductions.hpp:87
std::ptrdiff_t offset() const
Definition: reductions.hpp:99
Tensor forward() const
Definition: reductions.hpp:103
constexpr Reduction(Reducer reducer, typename Trait< Operand >::Reference operand)
Definition: reductions.hpp:80
constexpr Shape const & shape() const
Definition: reductions.hpp:91
Reducer reducer
Definition: reductions.hpp:77
constexpr Strides const & strides() const
Definition: reductions.hpp:95
Trait< Operand >::Reference operand
Definition: reductions.hpp:78
Definition: comparisons.hpp:69
constexpr auto mean(Source &&source, int axis=-1, bool keepdim=false)
Creates a mean reduction.
Definition: reductions.hpp:379
constexpr auto argmax(Source &&source, int axis=-1, bool keepdim=false)
Creates an Argmax reduction.
Definition: reductions.hpp:302
constexpr auto argmin(Source &&source, int axis=-1, bool keepdim=false)
Creates an Argmin reduction.
Definition: reductions.hpp:324
constexpr auto sum(Source &&source, int axis=-1, bool keepdim=false)
Creates a sum reduction.
Definition: reductions.hpp:351
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28
Finds the indices of maximum values along an axis.
Definition: reductions.hpp:137
bool keepdim
Definition: reductions.hpp:139
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:145
int axis
Definition: reductions.hpp:138
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:141
Computes the mean along an axis.
Definition: reductions.hpp:256
bool keepdim
Definition: reductions.hpp:258
constexpr type reduce(type dtype) const
Definition: reductions.hpp:260
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:265
int axis
Definition: reductions.hpp:257
void forward(Tensor const &input, Tensor &output) const
Finds the indices of minimum values along an axis.
Definition: reductions.hpp:175
bool keepdim
Definition: reductions.hpp:177
void forward(Tensor const &, Tensor &) const
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:183
int axis
Definition: reductions.hpp:176
constexpr type reduce(type dtype) const
Definition: reductions.hpp:179
Sums tensor values along an axis.
Definition: reductions.hpp:217
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:225
bool keepdim
Definition: reductions.hpp:219
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:221
int axis
Definition: reductions.hpp:218