Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
operations.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16#ifndef OPERATIONS_HPP
17#define OPERATIONS_HPP
18
39#include <utility>
40#include <type_traits>
41
42#include "concepts.hpp"
43#include "types.hpp"
44#include "shape.hpp"
45#include "strides.hpp"
46#include "traits.hpp"
47
48namespace tannic {
49
50class Tensor;
51
52} namespace tannic::operation {
53
/**
 * @brief Expression-template node for an element-wise unary operation.
 *
 * @tparam Operation Operation policy (e.g. `Negation`), stored by value in `operation`.
 * @tparam Operand   Expression type of the single input.
 *
 * dtype(), shape(), strides() and offset() all forward to the operand, so the node
 * reports exactly the metadata of its input. forward() is only declared here; its
 * definition lives outside this file.
 */
63template<class Operation, Expression Operand>
64class Unary {
65public:
// Operation policy instance (e.g. Negation) applied when the node is evaluated.
66 Operation operation;
// NOTE(review): the `operand` member (doc line 67) and the constructor (doc lines 74-76)
// are missing from this extracted listing. The generated index lists them as
// `Trait<Operand>::Reference operand` and
// `constexpr Unary(Operation operation, Trait<Operand>::Reference operand)`.
// Restore them from the original header.
68
77 {}
78
/// Data type of the result: identical to the operand's data type.
87 constexpr type dtype() const {
88 return operand.dtype();
89 }
90
/// Shape of the result: identical to the operand's shape.
98 constexpr Shape const& shape() const {
99 return operand.shape();
100 }
101
102
/// Strides of the result: identical to the operand's strides.
111 constexpr Strides const& strides() const {
112 return operand.strides();
113 }
114
115
/// Offset of the result: identical to the operand's offset.
121 std::ptrdiff_t offset() const {
122 return operand.offset();
123 }
124
/// Evaluates the expression into a Tensor (defined out of line, not in this file).
129 Tensor forward() const;
130};
131
/**
 * @brief Expression-template node for an element-wise binary operation.
 *
 * @tparam Operation  Operation policy (e.g. `Addition`) that provides `promote` and `broadcast`.
 * @tparam Operand    Expression type of the first input.
 * @tparam Cooperand  Expression type of the second input.
 *
 * Unlike Unary, the result metadata is computed once at construction and cached:
 * - the dtype comes from Operation::promote,
 * - the shape comes from Operation::broadcast,
 * - the strides are built from that shape.
 *
 * forward() is only declared here; its definition lives outside this file.
 */
143template<Operator Operation, Expression Operand, Expression Cooperand>
144class Binary {
145public:
// Operation policy instance used both for metadata (promote/broadcast) and evaluation.
146 Operation operation;
// NOTE(review): the `operand` and `cooperand` members (doc lines 147-148) and the start
// of the constructor (doc lines 161-164) are missing from this extracted listing. The
// generated index lists `Trait<Operand>::Reference operand`,
// `Trait<Cooperand>::Reference cooperand` and
// `constexpr Binary(Operation operation, Trait<Operand>::Reference operand,
//                   Trait<Cooperand>::Reference cooperand)`.
// Restore them from the original header.
149
// Result dtype: the operation's promotion rule applied to both input dtypes.
165 , dtype_(operation.promote(operand.dtype(), cooperand.dtype()))
// Result shape: the operation's broadcast rule. For the element-wise policies in this
// file it throws Exception when the shapes are not broadcast-compatible.
166 , shape_(operation.broadcast(operand.shape(), cooperand.shape()))
// Strides are derived from the already-initialized result shape. This is only valid
// because shape_ is declared before strides_ (members initialize in declaration order).
// NOTE(review): presumably this yields contiguous strides; confirm in strides.hpp.
167 , strides_(shape_)
168 {}
169
170
/// Data type of the result, as promoted at construction.
179 constexpr type dtype() const {
180 return dtype_;
181 }
182
/// Shape of the result, as broadcast at construction.
191 constexpr Shape const& shape() const {
192 return shape_;
193 }
194
/// Strides of the result, derived from the broadcast shape.
203 constexpr Strides const& strides() const {
204 return strides_;
205 }
206
/// Offset of the result; always 0.
// NOTE(review): presumably 0 because forward() materializes a fresh tensor rather than
// a view into an operand; confirm against the out-of-line forward().
215 std::ptrdiff_t offset() const {
216 return 0;
217 }
218
/// Evaluates the expression into a Tensor (defined out of line, not in this file).
219 Tensor forward() const;
220
221private:
// Cached result metadata, filled in by the constructor.
222 type dtype_;
223 Shape shape_;
224 Strides strides_;
225};
226
227
239static constexpr inline type promote(type first, type second) {
240 return static_cast<int>(first) > static_cast<int>(second) ? first : second;
241}
242
264static constexpr Shape broadcast(Shape const& first, Shape const& second) {
265 auto first_rank = first.rank();
266 auto second_rank = second.rank();
267 auto rank = std::max(first_rank, second_rank);
268 Shape shape;
269 for (auto dimension = 0; dimension < rank; ++dimension) {
270 auto first_dimension = (dimension < rank - first_rank) ? 1 : first[dimension - (rank - first_rank)];
271 auto second_dimension = (dimension < rank - second_rank) ? 1 : second[dimension - (rank - second_rank)];
272
273 if (!(first_dimension == second_dimension || first_dimension == 1 || second_dimension == 1)) {
274 throw Exception("Shapes are not broadcast-compatible.");
275 }
276 shape.expand(std::max(first_dimension, second_dimension));
277 }
278 return shape;
279}
280
281
305struct Negation {
306 void forward(Tensor const&, Tensor&) const;
307};
308
309
333struct Addition {
334 void forward(Tensor const&, Tensor const&, Tensor&) const;
335
336 constexpr static type promote(type first, type second) {
337 return operation::promote(first, second);
338 }
339
340 constexpr static Shape broadcast(Shape const& first, Shape const& second) {
341 return operation::broadcast(first, second);
342 }
343};
344
345
370 void forward(Tensor const&, Tensor const&, Tensor&) const;
371
372 constexpr static type promote(type first, type second) {
373 return operation::promote(first, second);
374 }
375
376 constexpr static Shape broadcast(Shape const& first, Shape const& second) {
377 return operation::broadcast(first, second);
378 }
379};
380
381
406 void forward(Tensor const&, Tensor const&, Tensor&) const;
407
408 constexpr static type promote(type first, type second) {
409 return operation::promote(first, second);
410 }
411
412 constexpr static Shape broadcast(Shape const& first, Shape const& second) {
413 return operation::broadcast(first, second);
414 }
415};
416
444 void forward(Tensor const&, Tensor const&, Tensor&) const;
445
446 constexpr static type promote(type first, type second) {
447 return operation::promote(first, second);
448 }
449
450 constexpr static Shape broadcast(Shape const& first, Shape const& second) {
451 return operation::broadcast(first, second);
452 }
453};
454
466template<Expression Operand>
467constexpr auto operator-(Operand&& operand) {
468 return Unary<Negation, Operand>{{}, std::forward<Operand>(operand)};
469}
470
487template<Expression Augend, Expression Addend>
488constexpr auto operator+(Augend&& augend, Addend&& addend) {
489 return Binary<Addition, Augend, Addend>{{}, std::forward<Augend>(augend), std::forward<Addend>(addend)};
490}
491
492
508template<Expression Subtrahend, Expression Minuend>
509constexpr auto operator-(Subtrahend&& subtrahend, Minuend&& minuend) {
510 return Binary<Subtraction, Subtrahend, Minuend>{{}, std::forward<Subtrahend>(subtrahend), std::forward<Minuend>(minuend)};
511}
512
513
529template<Expression Multiplicand, Expression Multiplier>
530constexpr auto operator*(Multiplicand&& multiplicand, Multiplier&& multiplier) {
531 return Binary<Multiplication, Multiplicand, Multiplier>{{}, std::forward<Multiplicand>(multiplicand), std::forward<Multiplier>(multiplier)};
532}
533
534
553template<Expression Base, Expression Exponent>
554constexpr auto operator^(Base&& base, Exponent&& exponent) {
555 return Binary<Exponentiation, Base, Exponent>{{}, std::forward<Base>(base), std::forward<Exponent>(exponent)};
556}
557
558} namespace tannic {
559
560using operation::operator-;
561using operation::operator+;
562using operation::operator*;
563using operation::operator^;
564
565} // namespace tannic
566
567#endif // OPERATIONS_HPP
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for a binary tensor operation.
Definition: operations.hpp:144
Trait< Operand >::Reference operand
Definition: operations.hpp:147
constexpr Strides const & strides() const
Returns the output strides for the result tensor.
Definition: operations.hpp:203
constexpr Shape const & shape() const
Returns the shape of the result.
Definition: operations.hpp:191
Operation operation
Definition: operations.hpp:146
Trait< Cooperand >::Reference cooperand
Definition: operations.hpp:148
constexpr Binary(Operation operation, Trait< Operand >::Reference operand, Trait< Cooperand >::Reference cooperand)
Constructs a Binary expression.
Definition: operations.hpp:161
std::ptrdiff_t offset() const
Returns the offset of the expression.
Definition: operations.hpp:215
Tensor forward() const
Definition: tensor.hpp:1158
constexpr type dtype() const
Returns the promoted data type of the result.
Definition: operations.hpp:179
Expression template for a unary tensor arithmetic operation.
Definition: operations.hpp:64
Trait< Operand >::Reference operand
Definition: operations.hpp:67
constexpr type dtype() const
Returns the data type of the result.
Definition: operations.hpp:87
constexpr Strides const & strides() const
Returns the strides of the result.
Definition: operations.hpp:111
Tensor forward() const
Evaluates the unary expression and returns a Tensor.
Definition: tensor.hpp:1151
std::ptrdiff_t offset() const
Returns the offset of the expression.
Definition: operations.hpp:121
Operation operation
Definition: operations.hpp:66
constexpr Unary(Operation operation, Trait< Operand >::Reference operand)
Constructs a Unary expression.
Definition: operations.hpp:74
constexpr Shape const & shape() const
Returns the shape of the result.
Definition: operations.hpp:98
Definition: operations.hpp:52
constexpr auto operator^(Base &&base, Exponent &&exponent)
Element-wise exponentiation of two tensor expressions.
Definition: operations.hpp:554
constexpr auto operator*(Multiplicand &&multiplicand, Multiplier &&multiplier)
Element-wise multiplication of two tensor expressions.
Definition: operations.hpp:530
constexpr auto operator+(Augend &&augend, Addend &&addend)
Element-wise addition of two tensor expressions.
Definition: operations.hpp:488
constexpr auto operator-(Operand &&operand)
Element-wise negation of a tensor expression.
Definition: operations.hpp:467
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28
Binary element-wise addition of two tensor expressions.
Definition: operations.hpp:333
static constexpr Shape broadcast(Shape const &first, Shape const &second)
Definition: operations.hpp:340
static constexpr type promote(type first, type second)
Definition: operations.hpp:336
void forward(Tensor const &, Tensor const &, Tensor &) const
Binary element-wise exponentiation of two tensor expressions.
Definition: operations.hpp:443
void forward(Tensor const &, Tensor const &, Tensor &) const
static constexpr Shape broadcast(Shape const &first, Shape const &second)
Definition: operations.hpp:450
static constexpr type promote(type first, type second)
Definition: operations.hpp:446
Binary element-wise multiplication of two tensor expressions.
Definition: operations.hpp:369
static constexpr type promote(type first, type second)
Definition: operations.hpp:372
static constexpr Shape broadcast(Shape const &first, Shape const &second)
Definition: operations.hpp:376
void forward(Tensor const &, Tensor const &, Tensor &) const
Unary element-wise negation of a tensor expression.
Definition: operations.hpp:305
void forward(Tensor const &, Tensor &) const
Binary element-wise subtraction of two tensor expressions.
Definition: operations.hpp:405
static constexpr type promote(type first, type second)
Definition: operations.hpp:408
static constexpr Shape broadcast(Shape const &first, Shape const &second)
Definition: operations.hpp:412
void forward(Tensor const &, Tensor const &, Tensor &) const