Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
reductions.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16// Copyright 2025 Eric Hermosis
17//
18// Licensed under the Apache License, Version 2.0 (the "License");
19// you may not use this file except in compliance with the License.
20// You may obtain a copy of the License at
21//
22// http://www.apache.org/licenses/LICENSE-2.0
23//
24// Unless required by applicable law or agreed to in writing, software
25// distributed under the License is distributed on an "AS IS" BASIS,
26// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27// See the License for the specific language governing permissions and
28// limitations under the License.
29
30#ifndef REDUCTIONS_HPP
31#define REDUCTIONS_HPP
32
50#include <array>
51#include <cassert>
52
53#include "concepts.hpp"
54#include "types.hpp"
55#include "traits.hpp"
56#include "shape.hpp"
57#include "tensor.hpp"
58#include "indexing.hpp"
59#include "exceptions.hpp"
60
61
62namespace tannic::expression {
63
64
/// Lazy reduction expression: pairs a `Reducer` policy with an `Operand`
/// expression and only computes the reduced tensor when `forward()` is called.
///
/// The result dtype and shape are derived eagerly, at construction, by calling
/// `reducer.reduce(operand.dtype())` and `reducer.reduce(operand.shape())`.
/// A reducer therefore has to provide `reduce(type)`, `reduce(Shape const&)`
/// and `forward(Tensor const&, Tensor&)` — the three members this class calls.
74template<class Reducer, Expression Operand>
75class Reduction {
76public:
77 Reducer reducer;
79
83 , dtype_(reducer.reduce(operand.dtype()))
84 , shape_(reducer.reduce(operand.shape()))
// strides_ is built from shape_; this is only valid because shape_ is declared
// (and hence initialized) before strides_ in the private section below.
85 , strides_(shape_) {}
86
// Element type of the reduced result, as decided by the reducer.
87 constexpr type dtype() const {
88 return dtype_;
89 }
90
// Shape of the reduced result, as decided by the reducer.
91 constexpr Shape const& shape() const {
92 return shape_;
93 }
94
// Strides constructed from the reduced shape (presumably contiguous — TODO
// confirm against the Strides(Shape) constructor in strides.hpp).
95 constexpr Strides const& strides() const {
96 return strides_;
97 }
98
// The result is a freshly allocated tensor, so it always starts at offset 0.
99 std::ptrdiff_t offset() const {
100 return 0;
101 }
102
// Evaluates the operand, allocates an output tensor described by this
// expression's dtype/shape/strides/offset, and lets the reducer fill it in.
103 Tensor forward() const {
104 Tensor source = operand.forward();
105 Tensor result(dtype(), shape(), strides(), offset());
106 reducer.forward(source, result);
107 return result;
108 }
109
110private:
// Declaration order matters: the constructor initializes these in this order.
111 type dtype_;
112 Shape shape_;
113 Strides strides_;
114};
115
136struct Argmax {
137 int axis;
138 bool keepdim;
139
140 constexpr type reduce(type dtype) const {
141 return int64;
142 }
143
144 constexpr Shape reduce(Shape const& shape) const {
145 if (shape.rank() == 0)
146 throw Exception("Cannot reduce scalar tensors");
147
148 Shape out;
149 for (size_t dim = 0; dim < shape.rank(); ++dim) {
150 if (dim != static_cast<size_t>(axis)) out.expand(shape[dim]);
151 else if (keepdim) out.expand(1);
152 }
153 return out;
154 }
155
156 void forward(Tensor const& input, Tensor& output) const;
157};
158
174struct Argmin {
175 int axis;
176 bool keepdim;
177
178 constexpr type reduce(type dtype) const {
179 return int64;
180 }
181
182 constexpr Shape reduce(Shape const& shape) const {
183 if (shape.rank() == 0)
184 throw Exception("Cannot reduce scalar tensors");
185
186 Shape reduced;
187 for (uint8_t dimension = 0; dimension < shape.rank(); ++dimension) {
188 if (dimension != static_cast<uint8_t>(axis))
189 reduced.expand(shape[dimension]);
190 }
191 return reduced;
192 }
193
194 void forward(Tensor const&, Tensor&) const;
195};
196
197
216struct Argsum {
217 int axis;
218 bool keepdim;
219
220 constexpr type reduce(type dtype) const {
221 return dtype;
222 }
223
224 constexpr Shape reduce(Shape const& shape) const {
225 Shape reduced;
226 for (size_t dim = 0; dim < shape.rank(); ++dim) {
227 if (dim != static_cast<size_t>(axis)) {
228 reduced.expand(shape[dim]);
229 } else if (keepdim) {
230 reduced.expand(1); // Keep reduced dim as size 1
231 }
232 }
233 return reduced;
234 }
235
236 void forward(Tensor const& input, Tensor& output) const;
237};
238
239
255struct Argmean {
256 int axis;
257 bool keepdim = false;
258
259 constexpr type reduce(type dtype) const {
260 assert(dtype == float32 | dtype == float64 && "Integral dtypes not supported.");
261 return dtype;
262 }
263
264 constexpr Shape reduce(Shape const& shape) const {
265 Shape reduced;
266 for (size_t dim = 0; dim < shape.rank(); ++dim) {
267 if (dim != static_cast<size_t>(axis)) {
268 reduced.expand(shape[dim]);
269 } else if (keepdim) {
270 reduced.expand(1);
271 }
272 }
273 return reduced;
274 }
275
276 void forward(Tensor const& input, Tensor& output) const;
277};
278
279
/// Creates a lazy Argmax reduction of `source` along `axis`.
///
/// `axis` may be negative: it is normalized into [0, rank) with
/// `indexing::normalize` against `source.shape().rank()` before being stored.
/// Elements of a braced initializer list are evaluated left to right, so the
/// rank is read before `source` is forwarded.
/// @param keepdim keep the reduced axis as a size-1 dimension in the result.
300template<Expression Source>
301constexpr auto argmax(Source&& source, int axis, bool keepdim = false) {
303 {indexing::normalize(axis, source.shape().rank()), keepdim}, std::forward<Source>(source)
304 };
305}
306
307
/// Creates a lazy Argmin reduction of `source` along `axis`.
///
/// `axis` may be negative: it is normalized into [0, rank) with
/// `indexing::normalize` against `source.shape().rank()` before being stored.
/// Elements of a braced initializer list are evaluated left to right, so the
/// rank is read before `source` is forwarded.
/// @param keepdim forwarded to the Argmin reducer, which decides the output shape.
322template<Expression Source>
323constexpr auto argmin(Source&& source, int axis, bool keepdim = false) {
325 {indexing::normalize(axis, source.shape().rank()), keepdim}, std::forward<Source>(source)
326 };
327}
328
/// Creates a lazy sum reduction of `source` along `axis`.
///
/// `axis` may be negative: it is normalized into [0, rank) with
/// `indexing::normalize` against `source.shape().rank()` before being stored.
/// The result keeps the dtype of `source` (see Argsum::reduce(type)).
/// @param keepdim keep the reduced axis as a size-1 dimension in the result.
349template<Expression Source>
350constexpr auto sum(Source&& source, int axis, bool keepdim = false) {
352 {indexing::normalize(axis, source.shape().rank()), keepdim},
353 std::forward<Source>(source)
354 };
355}
356
/// Creates a lazy mean reduction of `source` along `axis`.
///
/// `axis` may be negative: it is normalized into [0, rank) with
/// `indexing::normalize` against `source.shape().rank()` before being stored.
/// Argmean::reduce(type) asserts that the dtype is float32 or float64; this
/// check is active only in builds without NDEBUG.
/// @param keepdim keep the reduced axis as a size-1 dimension in the result.
377template<Expression Source>
378constexpr auto mean(Source&& source, int axis, bool keepdim = false) {
380 {indexing::normalize(axis, source.shape().rank()), keepdim},
381 std::forward<Source>(source)
382 };
383}
384
385} namespace tannic {
386
389using expression::sum;
390using expression::mean;
391
392} // namespace tannic
393
394#endif // REDUCTIONS_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr void expand(size_type size)
Appends a new trailing dimension of the given size to the shape.
Definition: shape.hpp:286
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Lazy reduction expression.
Definition: reductions.hpp:75
constexpr type dtype() const
Definition: reductions.hpp:87
std::ptrdiff_t offset() const
Definition: reductions.hpp:99
Tensor forward() const
Definition: reductions.hpp:103
constexpr Reduction(Reducer reducer, typename Trait< Operand >::Reference operand)
Definition: reductions.hpp:80
constexpr Shape const & shape() const
Definition: reductions.hpp:91
Reducer reducer
Definition: reductions.hpp:77
constexpr Strides const & strides() const
Definition: reductions.hpp:95
Trait< Operand >::Reference operand
Definition: reductions.hpp:78
Defines core C++20 concepts used throughout the Tannic Tensor Library.
Utilities for index normalization and slicing ranges in the Tannic Tensor Library.
Definition: comparisons.hpp:69
constexpr auto argmax(Source &&source, int axis, bool keepdim=false)
Creates an Argmax reduction.
Definition: reductions.hpp:301
constexpr auto sum(Source &&source, int axis, bool keepdim=false)
Creates a sum reduction.
Definition: reductions.hpp:350
constexpr auto argmin(Source &&source, int axis, bool keepdim=false)
Creates an Argmin reduction.
Definition: reductions.hpp:323
constexpr auto mean(Source &&source, int axis, bool keepdim=false)
Creates a mean reduction.
Definition: reductions.hpp:378
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
Defines the tannic::Shape class for representing tensor dimensions.
std::decay_t< T > Reference
Definition: traits.hpp:28
Finds the indices of maximum values along an axis.
Definition: reductions.hpp:136
bool keepdim
Definition: reductions.hpp:138
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:144
int axis
Definition: reductions.hpp:137
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:140
Computes the mean along an axis.
Definition: reductions.hpp:255
bool keepdim
Definition: reductions.hpp:257
constexpr type reduce(type dtype) const
Definition: reductions.hpp:259
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:264
int axis
Definition: reductions.hpp:256
void forward(Tensor const &input, Tensor &output) const
Finds the indices of minimum values along an axis.
Definition: reductions.hpp:174
bool keepdim
Definition: reductions.hpp:176
void forward(Tensor const &, Tensor &) const
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:182
int axis
Definition: reductions.hpp:175
constexpr type reduce(type dtype) const
Definition: reductions.hpp:178
Sums tensor values along an axis.
Definition: reductions.hpp:216
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:224
bool keepdim
Definition: reductions.hpp:218
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:220
int axis
Definition: reductions.hpp:217
Core multidimensional tensor class for the Tannic Tensor Library.
Core type system for the Tannic Tensor Library.