Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
reductions.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16// Copyright 2025 Eric Hermosis
17//
18// Licensed under the Apache License, Version 2.0 (the "License");
19// you may not use this file except in compliance with the License.
20// You may obtain a copy of the License at
21//
22// http://www.apache.org/licenses/LICENSE-2.0
23//
24// Unless required by applicable law or agreed to in writing, software
25// distributed under the License is distributed on an "AS IS" BASIS,
26// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27// See the License for the specific language governing permissions and
28// limitations under the License.
29
30#ifndef REDUCTIONS_HPP
31#define REDUCTIONS_HPP
32
50#include <array>
51#include <cassert>
52
53#include "concepts.hpp"
54#include "types.hpp"
55#include "traits.hpp"
56#include "shape.hpp"
57#include "tensor.hpp"
58#include "indexing.hpp"
59#include "exceptions.hpp"
60
61
62namespace tannic::expression {
63
64
74template<class Reducer, Expression Operand>
75class Reduction {
76public:
77 Reducer reducer;
79
83 , dtype_(reducer.reduce(operand.dtype()))
84 , shape_(reducer.reduce(operand.shape()))
85 , strides_(shape_) {}
86
87 constexpr type dtype() const {
88 return dtype_;
89 }
90
91 constexpr Shape const& shape() const {
92 return shape_;
93 }
94
95 constexpr Strides const& strides() const {
96 return strides_;
97 }
98
99 std::ptrdiff_t offset() const {
100 return 0;
101 }
102
103 Tensor forward() const {
104 Tensor source = operand.forward();
105 Tensor result(dtype(), shape(), strides(), offset());
106 reducer.forward(source, result);
107 return result;
108 }
109
110private:
111 type dtype_;
112 Shape shape_;
113 Strides strides_;
114};
115
116
137struct Argmax {
138 int axis;
139 bool keepdim;
140
141 constexpr type reduce(type dtype) const {
142 return int64;
143 }
144
145 constexpr Shape reduce(Shape const& shape) const {
146 if (shape.rank() == 0)
147 throw Exception("Cannot reduce scalar tensors");
148
149 Shape out;
150 for (size_t dim = 0; dim < shape.rank(); ++dim) {
151 if (dim != static_cast<size_t>(axis)) out.expand(shape[dim]);
152 else if (keepdim) out.expand(1);
153 }
154 return out;
155 }
156
157 void forward(Tensor const& input, Tensor& output) const;
158};
159
175struct Argmin {
176 int axis;
177 bool keepdim;
178
179 constexpr type reduce(type dtype) const {
180 return int64;
181 }
182
183 constexpr Shape reduce(Shape const& shape) const {
184 if (shape.rank() == 0)
185 throw Exception("Cannot reduce scalar tensors");
186
187 Shape reduced;
188 for (uint8_t dimension = 0; dimension < shape.rank(); ++dimension) {
189 if (dimension != static_cast<uint8_t>(axis))
190 reduced.expand(shape[dimension]);
191 }
192 return reduced;
193 }
194
195 void forward(Tensor const&, Tensor&) const;
196};
197
198
217struct Argsum {
218 int axis;
219 bool keepdim;
220
221 constexpr type reduce(type dtype) const {
222 return dtype;
223 }
224
225 constexpr Shape reduce(Shape const& shape) const {
226 Shape reduced;
227 for (size_t dim = 0; dim < shape.rank(); ++dim) {
228 if (dim != static_cast<size_t>(axis)) {
229 reduced.expand(shape[dim]);
230 } else if (keepdim) {
231 reduced.expand(1); // Keep reduced dim as size 1
232 }
233 }
234 return reduced;
235 }
236
237 void forward(Tensor const& input, Tensor& output) const;
238};
239
240
256struct Argmean {
257 int axis;
258 bool keepdim = false;
259
260 constexpr type reduce(type dtype) const {
261 assert(dtype == float32 | dtype == float64 && "Integral dtypes not supported.");
262 return dtype;
263 }
264
265 constexpr Shape reduce(Shape const& shape) const {
266 Shape reduced;
267 for (size_t dim = 0; dim < shape.rank(); ++dim) {
268 if (dim != static_cast<size_t>(axis)) {
269 reduced.expand(shape[dim]);
270 } else if (keepdim) {
271 reduced.expand(1);
272 }
273 }
274 return reduced;
275 }
276
277 void forward(Tensor const& input, Tensor& output) const;
278};
279
280
301template<Expression Source>
302constexpr auto argmax(Source&& source, int axis = -1, bool keepdim = false) {
304 {indexing::normalize(axis, source.shape().rank()), keepdim}, std::forward<Source>(source)
305 };
306}
307
308
323template<Expression Source>
324constexpr auto argmin(Source&& source, int axis = -1, bool keepdim = false) {
326 {indexing::normalize(axis, source.shape().rank()), keepdim}, std::forward<Source>(source)
327 };
328}
329
350template<Expression Source>
351constexpr auto sum(Source&& source, int axis = -1, bool keepdim = false) {
353 {indexing::normalize(axis, source.shape().rank()), keepdim},
354 std::forward<Source>(source)
355 };
356}
357
378template<Expression Source>
379constexpr auto mean(Source&& source, int axis = -1, bool keepdim = false) {
381 {indexing::normalize(axis, source.shape().rank()), keepdim},
382 std::forward<Source>(source)
383 };
384}
385
386} namespace tannic {
387
390using expression::sum;
391using expression::mean;
392
393} // namespace tannic
394
395#endif // REDUCTIONS_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr void expand(size_type size)
Expands the shape's last dimension with a given size.
Definition: shape.hpp:286
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Lazy reduction expression.
Definition: reductions.hpp:75
constexpr type dtype() const
Definition: reductions.hpp:87
std::ptrdiff_t offset() const
Definition: reductions.hpp:99
Tensor forward() const
Definition: reductions.hpp:103
constexpr Reduction(Reducer reducer, typename Trait< Operand >::Reference operand)
Definition: reductions.hpp:80
constexpr Shape const & shape() const
Definition: reductions.hpp:91
Reducer reducer
Definition: reductions.hpp:77
constexpr Strides const & strides() const
Definition: reductions.hpp:95
Trait< Operand >::Reference operand
Definition: reductions.hpp:78
Definition: comparisons.hpp:69
constexpr auto mean(Source &&source, int axis=-1, bool keepdim=false)
Creates a mean reduction.
Definition: reductions.hpp:379
constexpr auto argmax(Source &&source, int axis=-1, bool keepdim=false)
Creates an Argmax reduction.
Definition: reductions.hpp:302
constexpr auto argmin(Source &&source, int axis=-1, bool keepdim=false)
Creates an Argmin reduction.
Definition: reductions.hpp:324
constexpr auto sum(Source &&source, int axis=-1, bool keepdim=false)
Creates a sum reduction.
Definition: reductions.hpp:351
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28
Finds the indices of maximum values along an axis.
Definition: reductions.hpp:137
bool keepdim
Definition: reductions.hpp:139
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:145
int axis
Definition: reductions.hpp:138
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:141
Computes the mean along an axis.
Definition: reductions.hpp:256
bool keepdim
Definition: reductions.hpp:258
constexpr type reduce(type dtype) const
Definition: reductions.hpp:260
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:265
int axis
Definition: reductions.hpp:257
void forward(Tensor const &input, Tensor &output) const
Finds the indices of minimum values along an axis.
Definition: reductions.hpp:175
bool keepdim
Definition: reductions.hpp:177
void forward(Tensor const &, Tensor &) const
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:183
int axis
Definition: reductions.hpp:176
constexpr type reduce(type dtype) const
Definition: reductions.hpp:179
Sums tensor values along an axis.
Definition: reductions.hpp:217
constexpr Shape reduce(Shape const &shape) const
Definition: reductions.hpp:225
bool keepdim
Definition: reductions.hpp:219
void forward(Tensor const &input, Tensor &output) const
constexpr type reduce(type dtype) const
Definition: reductions.hpp:221
int axis
Definition: reductions.hpp:218