Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
transformations.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16#ifndef TRANSFORMATIONS_HPP
17#define TRANSFORMATIONS_HPP
18
30#include <tuple>
31#include <array>
32#include <vector>
33#include <cassert>
34
35#include "concepts.hpp"
36#include "types.hpp"
37#include "traits.hpp"
38#include "shape.hpp"
39#include "tensor.hpp"
40#include "exceptions.hpp"
41
42namespace tannic {
43
44class Tensor;
45
46namespace transformation {
47
57template<class Operation, Expression ... Operands>
59public:
60 Operation operation;
61 std::tuple<typename Trait<Operands>::Reference...> operands;
62
70 , operands(operands...)
71 , dtype_(operation.promote(operands.dtype()...))
72 , shape_(operation.transform(operands.shape()...))
73 , strides_(shape_)
74 {}
75
80 constexpr type dtype() const {
81 return dtype_;
82 }
83
88 constexpr Shape const& shape() const {
89 return shape_;
90 }
91
96 constexpr Strides const& strides() const {
97 return strides_;
98 }
99
104 std::ptrdiff_t offset() const {
105 return 0;
106 }
107
112 Tensor forward() const {
113 Tensor result(dtype_, shape_);
114 std::apply([&](const auto&... arguments) {
115 return operation.forward(arguments.forward()..., result);
116 }, operands);
117 return result;
118 }
119
120private:
121 type dtype_;
122 Shape shape_;
123 Strides strides_;
124};
125
132static constexpr auto index(type first, type second) {
133 return static_cast<int>(first) + static_cast<int>(second) * static_cast<int>(TYPES);
134}
135
144struct Composition {
145 double scale = 1.0;
146
155 static constexpr auto promotions = []() {
156 std::array<type, index(TYPES, TYPES)> table{};
157 table.fill(unknown);
158 // Integer promotions
159 table[index(int8, int8)] = int32;
160 table[index(int8, int16)] = int32;
161 table[index(int8, int32)] = int32;
162 table[index(int8, int64)] = int64;
163
164 table[index(int16, int8)] = int32;
165 table[index(int16, int16)] = int32;
166 table[index(int16, int32)] = int32;
167 table[index(int16, int64)] = int64;
168
169 table[index(int32, int8)] = int32;
170 table[index(int32, int16)] = int32;
171 table[index(int32, int32)] = int64;
172 table[index(int32, int64)] = int64;
173
174 table[index(int64, int8)] = int64;
175 table[index(int64, int16)] = int64;
176 table[index(int64, int32)] = int64;
177 table[index(int64, int64)] = int64;
178
179 // Mixed integer/float promotions
180 table[index(int32, float32)] = float32;
181 table[index(float32, int32)] = float32;
182 table[index(int32, float64)] = float64;
183 table[index(float64, int32)] = float64;
184
185 // Float promotions
186 table[index(float32, float32)] = float32;
187 table[index(float32, float64)] = float64;
188 table[index(float64, float32)] = float64;
189 table[index(float64, float64)] = float64;
190 return table;
191 }();
192
216 constexpr static type promote(type inner, type outer) {
217 type dtype = promotions[index(inner, outer)];
218 if (dtype == unknown)
219 throw Exception("Unsuported dtypes");
220 return dtype;
221 }
222
237 static constexpr Shape transform(Shape const& first, Shape const& second) {
238 auto first_rank = first.rank();
239 auto second_rank = second.rank();
240
241 if (first_rank == 1 && second_rank == 1) {
242 if (first_rank != 1 | second_rank != 1)
243 throw Exception("dimensions must match for dot product");
244 return Shape{}; // Scalar result
245 }
246
247 if (first_rank == 1 && second_rank == 2) {
248 if (first[0] != second[0])
249 throw Exception("Matrix inner dimensions do not match");
250 if (first[0] != second[0])
251 return Shape{second[0]}; // Vector result
252 }
253
254 // Vector-matrix multiplication
255 if (first_rank == 2 && second_rank == 1) {
256 if (first[1] != second[0])
257 throw Exception("Matrix inner dimensions do not match");
258 return Shape{first[0]}; // Vector result
259 }
260
261 if (first_rank < 2 | second_rank < 2)
262 throw Exception("Inputs must have rank >= 2");
263
264 // Handle batch dimensions
265 Shape first_batches(first.begin(), first.end() - 2);
266 Shape second_batches(second.begin(), second.end() - 2);
267 Shape batches = operation::broadcast(first_batches, second_batches);
268
269 // Check matrix inner dimensions
270 auto K1 = *(first.end() - 1);
271 auto K2 = *(second.end() - 2);
272 assert(K1 == K2 && "Inner dimensions must match for matmul");
273
274 // Get output matrix dimensions
275 auto M = *(first.end() - 2);
276 auto N = *(second.end() - 1);
277
278 // Combine batch and matrix dimensions
279 std::vector<Shape::size_type> result(batches.begin(), batches.end());
280 result.push_back(M);
281 result.push_back(N);
282 return Shape(result);
283 }
284
285 void forward(Tensor const& outer, Tensor const& inner, Tensor& result) const;
286};
287
295struct Outer {
304 static constexpr type promote(type first, type second) {
305 return static_cast<int>(first) > static_cast<int>(second) ? first : second;
306 }
307
322 static constexpr Shape transform(Shape const& first, Shape const& second) {
323 if(first.rank() != 1 | second.rank() != 1)
324 throw Exception("Outer product of tensors with rank more than 1 not supported");
325 return Shape(first[0], second[0]);
326 }
327
328 void forward(Tensor const& first, Tensor const& second, Tensor& result) const;
329};
330
338 int axis;
339
348 constexpr type promote(type dtype) const {
349 return dtype;
350 }
351
360 constexpr Shape transform(Shape shape) const {
361 shape[indexing::normalize(axis, shape.rank())] *= repeats;
362 return shape;
363 }
364
365 void forward(Tensor const&, Tensor&) const;
366};
367
368
375 int axis;
385 constexpr auto promote(type first, type second) const {
386 if(first != second)
387 throw Exception("Cannot concatenate tensors of different dtypes");
388 return first;
389 }
390
401 constexpr Shape transform(Shape const& first, Shape const& second) {
402 assert(first.rank() == second.rank() && "Ranks must match for concatenation");
403 Shape result = first;
404
405 for (int dimension = 0; dimension < first.rank(); ++dimension) {
406 if (dimension == axis) {
407 result[dimension] = first[dimension] + second[dimension];
408 } else {
409 if (first[dimension] != second[dimension])
410 throw Exception("All dimensions except concat axis must match");
411 }
412 }
413
414 return result;
415 }
416
417 void forward(Tensor const&, Tensor const&, Tensor&) const;
418};
419
426struct Repack {
432 constexpr type promote(type dtype) const {
433 return dtype;
434 }
435
441 constexpr Shape transform(Shape const& shape) const {
442 return shape;
443 }
444
451 void forward(Tensor const& source, Tensor& result) const;
452};
453
/// Creates a lazy composition (matrix multiplication) expression.
///
/// @param outer  left operand, perfectly forwarded into the expression.
/// @param inner  right operand, perfectly forwarded into the expression.
/// @param scale  passed as `{scale}`, presumably initializing Composition::scale.
/// @return An unevaluated expression. Presumably the shape/dtype checks
///         (Composition::transform/promote) run when the expression is constructed.
///         TODO confirm against the full source.
462template<Expression Outer, Expression Inner>
463constexpr auto composition(Outer&& outer, Inner&& inner, double scale) {
465 {scale},
466 std::forward<Outer>(outer),
467 std::forward<Inner>(inner)
468 };
469}
470
/// Creates a lazy outer-product expression of two vectors.
///
/// The operation is passed as `{}` (default state); both operands are perfectly
/// forwarded. Presumably the rank-1 requirement is enforced by Outer::transform
/// when the expression is built. TODO confirm.
479template<Expression First, Expression Second>
480constexpr auto outer(First&& first, Second&& second) {
482 {},
483 std::forward<First>(first),
484 std::forward<Second>(second)
485 );
486}
487
488
/// Creates a lazy repetition expression.
///
/// @param source   expression to repeat, perfectly forwarded.
/// @param repeats  repetition count along `axis` (not validated here).
/// @param axis     axis to repeat along. It is normalized against the source rank
///                 via indexing::normalize, so negative values count from the end.
462// (listing line numbers are part of this scraped page and are left untouched)
498template<Expression Source>
499constexpr auto repeat(Source&& source, int repeats, int axis = 0) {
501 {repeats, indexing::normalize(axis, source.shape().rank())},
502 std::forward<Source>(source)
503 );
504}
505
506
/// Creates a lazy concatenation expression along `axis`.
///
/// @param first   first operand, perfectly forwarded.
/// @param second  second operand, perfectly forwarded.
/// @param axis    concatenation axis. Unlike repeat(), a negative axis is rejected
///                here rather than normalized.
/// @throws Exception if `axis` is negative.
517template<Expression First, Expression Second>
518constexpr auto concatenate(First&& first, Second&& second, int axis = 0) {
519 if(axis < 0)
520 throw Exception("Negative index not supported in concat");
521
523 {axis},
524 std::forward<First>(first),
525 std::forward<Second>(second)
526 );
527}
528
/// Creates a lazy repack expression, which copies `source` into a contiguous layout.
///
/// @param source  expression to repack, perfectly forwarded. The operation is
///                passed as `{}` (default state).
538template<Expression Source>
539constexpr auto repack(Source&& source) {
541 {},
542 std::forward<Source>(source)
543 );
544}
545
/// Reshapes `source`, always repacking it into a contiguous layout first and then
/// viewing the repacked result with `indexes...` (presumably the new extents).
///
/// NOTE(review): `source` is passed to repack() as an lvalue, not std::forward-ed.
/// Assuming Trait<...>::Reference decays the type, the source expression is then
/// copied into the repack node. Confirm whether forwarding was intended.
554template<Expression Source, Integral ... Indexes>
555constexpr auto reshape(Source&& source, Indexes ... indexes) {
557 repack(source), indexes...
558 );
559}
560
561} // namespace transformation
562
568
/// Matrix multiplication convenience function.
///
/// @param multiplicand  left operand, perfectly forwarded.
/// @param multiplier    right operand, perfectly forwarded.
/// @param scale         scale factor, defaulting to 1.0. It is passed on together with
///                      the operands, presumably to transformation::composition.
///                      TODO confirm.
577template<Expression Multiplicand, Expression Multiplier>
578constexpr auto matmul(Multiplicand&& multiplicand, Multiplier&& multiplier, double scale = 1.0) {
580 std::forward<Multiplicand>(multiplicand),
581 std::forward<Multiplier>(multiplier),
582 scale
583 );
584}
585
586} // namespace tannic
587
588#endif // TRANSFORMATIONS_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr auto begin()
Definition: shape.hpp:215
constexpr auto end()
Definition: shape.hpp:219
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for viewing a tensor with a new shape.
Definition: views.hpp:84
Expression template for tensor transformations.
Definition: transformations.hpp:58
Operation operation
Definition: transformations.hpp:60
constexpr type dtype() const
Gets the result data type after promotion.
Definition: transformations.hpp:80
constexpr Strides const & strides() const
Gets the computed strides for the result.
Definition: transformations.hpp:96
std::tuple< typename Trait< Operands >::Reference... > operands
Definition: transformations.hpp:61
constexpr Transformation(Operation operation, typename Trait< Operands >::Reference... operands)
Constructs a Transformation expression.
Definition: transformations.hpp:68
Tensor forward() const
Evaluates the transformation.
Definition: transformations.hpp:112
std::ptrdiff_t offset() const
Gets the data offset (always 0 for new tensors)
Definition: transformations.hpp:104
constexpr Shape const & shape() const
Gets the output shape computed by the operation's transform.
Definition: transformations.hpp:88
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Requires a type to be an integral type (e.g., int, std::size_t).
Definition: concepts.hpp:147
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
constexpr auto repeat(Source &&source, int repeats, int axis=0)
Creates a repetition transformation.
Definition: transformations.hpp:499
constexpr auto reshape(Source &&source, Indexes ... indexes)
Creates a view but always repacks the tensor into a contiguous layout.
Definition: transformations.hpp:555
constexpr auto outer(First &&first, Second &&second)
Creates an outer product expression.
Definition: transformations.hpp:480
constexpr auto composition(Outer &&outer, Inner &&inner, double scale)
Creates a composition (matrix multiplication) expression.
Definition: transformations.hpp:463
constexpr auto concatenate(First &&first, Second &&second, int axis=0)
Helper function to create a concatenation transformation.
Definition: transformations.hpp:518
constexpr auto repack(Source &&source)
Creates a repack transformation.
Definition: transformations.hpp:539
Definition: buffer.hpp:41
constexpr auto matmul(Multiplicand &&multiplicand, Multiplier &&multiplier, double scale=1.0)
Matrix multiplication convenience function.
Definition: transformations.hpp:578
std::decay_t< T > Reference
Definition: traits.hpp:28
Transformation composition (Known as Matrix Multiplication) operation.
Definition: transformations.hpp:144
void forward(Tensor const &outer, Tensor const &inner, Tensor &result) const
static constexpr Shape transform(Shape const &first, Shape const &second)
Computes transformed output shape for composition.
Definition: transformations.hpp:237
double scale
Definition: transformations.hpp:145
static constexpr type promote(type inner, type outer)
Promotes two operand types to a common type for composition operations.
Definition: transformations.hpp:216
static constexpr auto promotions
Type promotion rules table.
Definition: transformations.hpp:155
Concatenation operation along a specified axis.
Definition: transformations.hpp:374
void forward(Tensor const &, Tensor const &, Tensor &) const
int axis
Definition: transformations.hpp:375
constexpr auto promote(type first, type second) const
Type promotion for concatenation.
Definition: transformations.hpp:385
constexpr Shape transform(Shape const &first, Shape const &second)
Computes output shape after concatenation.
Definition: transformations.hpp:401
Represents the outer product operation between two vectors.
Definition: transformations.hpp:295
static constexpr Shape transform(Shape const &first, Shape const &second)
Computes output shape for the outer product of two vectors.
Definition: transformations.hpp:322
static constexpr type promote(type first, type second)
Type promotion for the outer product operation.
Definition: transformations.hpp:304
void forward(Tensor const &first, Tensor const &second, Tensor &result) const
Repack operation (makes a tensor contiguous in memory)
Definition: transformations.hpp:426
void forward(Tensor const &source, Tensor &result) const
Copies the tensor data to a contiguous layout if needed.
constexpr type promote(type dtype) const
Type promotion for repack.
Definition: transformations.hpp:432
constexpr Shape transform(Shape const &shape) const
Computes output shape after repack.
Definition: transformations.hpp:441
Repetition operation along a specified axis.
Definition: transformations.hpp:336
constexpr Shape transform(Shape shape) const
Computes output shape after repetition.
Definition: transformations.hpp:360
int axis
Definition: transformations.hpp:338
void forward(Tensor const &, Tensor &) const
int repeats
Definition: transformations.hpp:337
constexpr type promote(type dtype) const
Type promotion for repetition operation.
Definition: transformations.hpp:348