Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
views.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef VIEWS_HPP
19#define VIEWS_HPP
20
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <numeric>
#include <tuple>
#include <utility>
#include <vector>

#include "types.hpp"
#include "traits.hpp"
#include "shape.hpp"
#include "strides.hpp"
#include "concepts.hpp"
#include "exceptions.hpp"
58
59namespace tannic {
60
61class Tensor;
62
63namespace expression {
64
83template<Expression Source>
84class View {
85public:
86
97 template<Integral... Sizes>
98 constexpr View(typename Trait<Source>::Reference source, Sizes... sizes)
99 : source_(source)
100 {
101 std::array<long long, sizeof...(Sizes)> requested{ static_cast<long long>(sizes)... };
102
103 std::size_t source_elements = std::accumulate(
104 source.shape().begin(), source.shape().end(), 1ULL, std::multiplies<>{}
105 );
106
107 // Step 3. Handle -1 (infer axis)
108 int infer_axis = -1;
109 std::size_t known_product = 1;
110 for (std::size_t i = 0; i < requested.size(); ++i) {
111 if (requested[i] == -1) {
112 if (infer_axis != -1)
113 throw Exception("Only one dimension can be inferred (-1) in view");
114 infer_axis = static_cast<int>(i);
115 } else if (requested[i] < 0) {
116 throw Exception("Invalid negative dimension in view");
117 } else {
118 known_product *= static_cast<std::size_t>(requested[i]);
119 }
120 }
121
122 if (infer_axis != -1) {
123 if (source_elements % known_product != 0)
124 throw Exception("Cannot infer dimension: source elements not divisible");
125 requested[infer_axis] = static_cast<long long>(source_elements / known_product);
126 }
127
128 for (auto v : requested) {
129 shape_.expand(static_cast<std::size_t>(v));
130 }
131
132 std::size_t new_elements = std::accumulate(shape_.begin(), shape_.end(), 1ULL, std::multiplies<>{});
133 if (new_elements != source_elements)
134 throw Exception("Shape mismatch: view must preserve total number of elements");
135
136 strides_ = Strides(shape_);
137 }
138
139
147 constexpr type dtype() const {
148 return source_.dtype();
149 }
150
157 constexpr Shape const& shape() const {
158 return shape_;
159 }
160
167 constexpr Strides const& strides() const {
168 return strides_;
169 }
170
180 std::ptrdiff_t offset() const {
181 return source_.offset();
182 }
183
184 Tensor forward() const;
185
186private:
187 Shape shape_;
188 Strides strides_;
189 typename Trait<Source>::Reference source_;
190};
191
192
210template<Expression Source>
212public:
213
220 constexpr Transpose(typename Trait<Source>::Reference source, std::pair<int, int> dimensions)
221 : shape_(source.shape())
222 , strides_(source.strides())
223 , source_(source)
224 , dimensions_(dimensions) {
225 auto rank = source.shape().rank();
226 std::swap(shape_[indexing::normalize(dimensions.first, rank)], shape_[indexing::normalize(dimensions.second, rank)]);
227 std::swap(strides_[indexing::normalize(dimensions.first, rank)], strides_[indexing::normalize(dimensions.second, rank)]);
228 }
229
236 constexpr type dtype() const {
237 return source_.dtype();
238 }
239
246 constexpr Shape const& shape() const {
247 return shape_;
248 }
249
257 constexpr Strides const& strides() const {
258 return strides_;
259 }
260
268 std::ptrdiff_t offset() const {
269 return source_.offset();
270 }
271
272 Tensor forward() const;
273
274private:
275 Shape shape_;
276 Strides strides_;
277 typename Trait<Source>::Reference source_;
278 std::pair<int, int> dimensions_;
279};
280
281
302template<Expression Source, Integral ... Indexes>
304public:
316 constexpr Permutation(typename Trait<Source>::Reference source, std::tuple<Indexes...> indexes)
317 : source_(source)
318 {
319 if (sizeof...(Indexes) != source_.shape().rank()) {
320 throw Exception("Permutation rank must equal tensor rank");
321 }
322
323 std::apply([&](auto... indexes) {
324 (([&]{
325 int dimension = indexing::normalize(indexes, source_.shape().rank());
326 shape_.expand(source_.shape()[dimension]);
327 strides_.expand(source_.strides()[dimension]);
328 }()), ...);
329 }, indexes);
330 }
331
337 constexpr type dtype() const {
338 return source_.dtype();
339 }
340
347 constexpr Shape const& shape() const {
348 return shape_;
349 }
350
357 constexpr Strides const& strides() const {
358 return strides_;
359 }
360
367 std::ptrdiff_t offset() const {
368 return source_.offset();
369 }
370
371 Tensor forward() const;
372
373private:
374 Shape shape_{};
375 Strides strides_{};
376 typename Trait<Source>::Reference source_;
377};
378
379
380
397template<Expression Source>
399public:
409 template<Integral... Sizes>
410 constexpr Expansion(typename Trait<Source>::Reference source, Sizes... sizes)
411 : source_(source) {
412 std::array<long long, sizeof...(Sizes)> requested{ static_cast<long long>(sizes)... };
413
414 if (requested.size() < source.shape().rank())
415 throw Exception("Expansion target rank must be >= source rank");
416
417 std::size_t offset = requested.size() - source.shape().rank();
418 for (std::size_t dimension = 0; dimension < requested.size(); ++dimension) {
419 long long index = requested[dimension];
420 std::size_t target;
421
422 if (index == -1) {
423 if (dimension < offset) {
424 throw Exception("Cannot use -1 for new leading dimensions");
425 }
426 target = source.shape()[dimension - offset];
427 } else if (index <= 0) {
428 throw Exception("Expansion size must be positive or -1");
429 } else {
430 target = static_cast<std::size_t>(index);
431 }
432
433 if (dimension < offset) {
434 shape_.expand(target);
435 strides_.expand(0);
436 }
437
438 else {
439 if (source.shape()[dimension - offset] == 1 && target > 1) {
440 shape_.expand(target);
441 strides_.expand(0); // broadcast
442 } else if (source.shape()[dimension - offset] == target) {
443 shape_.expand(target);
444 strides_.expand(source.strides()[dimension - offset]);
445 } else {
446 throw Exception("Expansion only allows -1 (keep) or broadcasting singleton dims");
447 }
448 }
449 }
450 }
451
452
459 constexpr type dtype() const {
460 return source_.dtype();
461 }
462
472 constexpr Shape const& shape() const {
473 return shape_;
474 }
475
476
485 constexpr Strides const& strides() const {
486 return strides_;
487 }
488
495 std::ptrdiff_t offset() const {
496 return source_.offset();
497 }
498
499 Tensor forward() const;
500
501private:
502 Shape shape_;
503 Strides strides_;
504 typename Trait<Source>::Reference source_;
505};
506
522template<Expression Source>
523class Squeeze {
524public:
533 constexpr Squeeze(typename Trait<Source>::Reference source)
534 : source_(source) {
535 for (auto dimension = 0; dimension < source.shape().rank(); ++dimension) {
536 if (source.shape()[dimension] != 1) {
537 shape_.expand(source.shape()[dimension]);
538 strides_.expand(source.strides()[dimension]);
539 }
540 }
541 }
542
549 constexpr type dtype() const {
550 return source_.dtype();
551 }
552
563 constexpr Shape const& shape() const {
564 return shape_;
565 }
566
578 constexpr Strides const& strides() const {
579 return strides_;
580 }
581
588 std::ptrdiff_t offset() const {
589 return source_.offset();
590 }
591
592 Tensor forward() const;
593
594private:
595 Shape shape_{};
596 Strides strides_{};
597 typename Trait<Source>::Reference source_;
598};
599
600
601
618template<Expression Source>
620public:
634 template<Integral... Axes>
635 constexpr Unsqueeze(typename Trait<Source>::Reference source, Axes... axes)
636 : source_(source) {
637 auto rank = source.shape().rank();
638 std::vector<std::size_t> normalized{ static_cast<std::size_t>(indexing::normalize(axes, rank + sizeof...(axes)))... };
639 std::sort(normalized.begin(), normalized.end());
640
641 size_t dimensions = rank + normalized.size();
642 size_t index = 0;
643 size_t axis = 0;
644
645 for (auto dimension = 0; dimension < dimensions; ++dimension) {
646 if (axis < normalized.size() && dimension == normalized[axis]) {
647 shape_.expand(1);
648 strides_.expand( (index < source.strides().rank()) ? source.strides()[index] : 1 );
649 ++axis;
650 } else {
651 shape_.expand(source.shape()[index]);
652 strides_.expand(source.strides()[index]);
653 ++index;
654 }
655 }
656 }
657
664 constexpr type dtype() const {
665 return source_.dtype();
666 }
667
679 constexpr Shape const& shape() const {
680 return shape_;
681 }
682
691 constexpr Strides const& strides() const {
692 return strides_;
693 }
694
701 std::ptrdiff_t offset() const {
702 return source_.offset();
703 }
704
705 Tensor forward() const;
706
707private:
708 Shape shape_{};
709 Strides strides_{};
710 typename Trait<Source>::Reference source_;
711};
712
713
731template<Expression Source>
732class Flatten {
733public:
734 constexpr Flatten(typename Trait<Source>::Reference source, int start = 0, int end = -1)
735 : source_(source) {
736 auto rank = source.shape().rank();
737
738 start = indexing::normalize(start, rank);
739 end = indexing::normalize(end, rank);
740
741 if (start > end) {
742 throw Exception("Flatten requires start_dim <= end_dim");
743 }
744
745 for (int dimension = 0; dimension < start; ++dimension) {
746 shape_.expand(source.shape()[dimension]);
747 strides_.expand(source.strides()[dimension]);
748 }
749
750 std::size_t flattened = 1;
751 for (int dimension = start; dimension <= end; ++dimension) {
752 flattened *= source.shape()[dimension];
753 }
754 shape_.expand(flattened);
755 strides_.expand(source.strides()[end]);
756
757 for (int dimension = end + 1; dimension < rank; ++dimension) {
758 shape_.expand(source.shape()[dimension]);
759 strides_.expand(source.strides()[dimension]);
760 }
761 }
762
769 constexpr type dtype() const { return source_.dtype(); }
770
784 constexpr Shape const& shape() const {
785 return shape_;
786 }
787
802 constexpr Strides const& strides() const {
803 return strides_;
804 }
805
812 std::ptrdiff_t offset() const { return source_.offset(); }
813
814 Tensor forward() const;
815
816private:
817 Shape shape_{};
818 Strides strides_{};
819 typename Trait<Source>::Reference source_;
820};
821
822
823
824
825/*
826----------------------------------------------------------------------------------------------------
827*/
828
829
839template<Expression Source, Integral ... Indexes>
840constexpr auto view(Source&& source, Indexes ... indexes) {
841 return View<Source>(
842 std::forward<Source>(source), indexes...
843 );
844}
845
846
856template<Expression Source>
857constexpr auto transpose(Source&& source, int first, int second) {
858 return Transpose<Source>(
859 std::forward<Source>(source),
860 std::make_pair(first, second)
861 );
862}
863
879template<Expression Source, Integral ... Indexes>
880constexpr auto permute(Source&& source, Indexes... indexes) {
881 return Permutation<Source, Indexes...>(
882 std::forward<Source>(source),
883 std::make_tuple(indexes...)
884 );
885}
886
887
914template<Expression Source, Integral... Sizes>
915constexpr auto expand(Source&& source, Sizes... sizes) {
916 return Expansion<Source>(std::forward<Source>(source), sizes...);
917}
918
919
936template<Expression Source>
937constexpr auto squeeze(Source&& source) {
938 return Squeeze<Source>(std::forward<Source>(source));
939}
940
941
961template<Expression Source, Integral... Axes>
962constexpr auto unsqueeze(Source&& source, Axes... axes) {
963 return Unsqueeze<Source>(std::forward<Source>(source), axes...);
964}
965
966
983template<Expression Source>
984constexpr auto flatten(Source&& source, int start = 0, int end = -1) {
985 return Flatten<Source>(std::forward<Source>(source), start, end);
986}
987
988
989} // namespace expression
990
991using expression::view;
998
999} // namespace tannic
1000
1001#endif // VIEWS_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr auto begin()
Definition: shape.hpp:215
constexpr void expand(size_type size)
Appends a new trailing dimension of the given size to the shape.
Definition: shape.hpp:286
constexpr auto end()
Definition: shape.hpp:219
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
constexpr void expand(size_type size)
Appends a new trailing stride of the given value to the strides.
Definition: strides.hpp:267
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for expanding (broadcasting) singleton dimensions of a tensor.
Definition: views.hpp:398
Tensor forward() const
Definition: tensor.hpp:1201
constexpr Expansion(typename Trait< Source >::Reference source, Sizes... sizes)
Construct an expansion view.
Definition: views.hpp:410
constexpr Shape const & shape() const
Definition: views.hpp:472
constexpr Strides const & strides() const
Definition: views.hpp:485
std::ptrdiff_t offset() const
Definition: views.hpp:495
constexpr type dtype() const
Definition: views.hpp:459
Expression template for flattening a contiguous range of dimensions.
Definition: views.hpp:732
constexpr Strides const & strides() const
Definition: views.hpp:802
Tensor forward() const
Definition: tensor.hpp:1183
constexpr Shape const & shape() const
Definition: views.hpp:784
std::ptrdiff_t offset() const
Definition: views.hpp:812
constexpr type dtype() const
Definition: views.hpp:769
constexpr Flatten(typename Trait< Source >::Reference source, int start=0, int end=-1)
Definition: views.hpp:734
Expression template for reordering tensor dimensions according to a specified permutation.
Definition: views.hpp:303
std::ptrdiff_t offset() const
Definition: views.hpp:367
Tensor forward() const
Definition: tensor.hpp:1189
constexpr Strides const & strides() const
Definition: views.hpp:357
constexpr Permutation(typename Trait< Source >::Reference source, std::tuple< Indexes... > indexes)
Constructs a permuted view of the source tensor.
Definition: views.hpp:316
constexpr Shape const & shape() const
Definition: views.hpp:347
constexpr type dtype() const
Definition: views.hpp:337
Expression template for removing singleton dimensions from a tensor.
Definition: views.hpp:523
std::ptrdiff_t offset() const
Definition: views.hpp:588
constexpr Squeeze(typename Trait< Source >::Reference source)
Construct a squeezed view of the source tensor.
Definition: views.hpp:533
constexpr Shape const & shape() const
Definition: views.hpp:563
constexpr type dtype() const
Definition: views.hpp:549
Tensor forward() const
Definition: tensor.hpp:1171
constexpr Strides const & strides() const
Definition: views.hpp:578
Expression template for transposing two dimensions of a tensor.
Definition: views.hpp:211
Tensor forward() const
Definition: tensor.hpp:1195
constexpr Shape const & shape() const
Definition: views.hpp:246
std::ptrdiff_t offset() const
Definition: views.hpp:268
constexpr Transpose(typename Trait< Source >::Reference source, std::pair< int, int > dimensions)
Construct a transposed view of the source tensor.
Definition: views.hpp:220
constexpr Strides const & strides() const
Definition: views.hpp:257
constexpr type dtype() const
Definition: views.hpp:236
Expression template for inserting singleton dimensions into a tensor.
Definition: views.hpp:619
constexpr Unsqueeze(typename Trait< Source >::Reference source, Axes... axes)
Construct an unsqueezed view of the source tensor.
Definition: views.hpp:635
constexpr Strides const & strides() const
Definition: views.hpp:691
constexpr type dtype() const
Definition: views.hpp:664
Tensor forward() const
Definition: tensor.hpp:1177
constexpr Shape const & shape() const
Definition: views.hpp:679
std::ptrdiff_t offset() const
Definition: views.hpp:701
Expression template for viewing a tensor with a new shape.
Definition: views.hpp:84
constexpr Shape const & shape() const
Definition: views.hpp:157
Tensor forward() const
Definition: tensor.hpp:1165
constexpr type dtype() const
Definition: views.hpp:147
std::ptrdiff_t offset() const
Definition: views.hpp:180
constexpr Strides const & strides() const
Definition: views.hpp:167
constexpr View(typename Trait< Source >::Reference source, Sizes... sizes)
Construct a reshaped view of the source tensor.
Definition: views.hpp:98
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Requires a type to be an integral type (e.g., int, std::size_t).
Definition: concepts.hpp:147
constexpr auto permute(Source &&source, Indexes... indexes)
Creates a permuted view of a tensor or expression.
Definition: views.hpp:880
constexpr auto expand(Source &&source, Sizes... sizes)
Creates an expanded view of a tensor, broadcasting singleton dimensions.
Definition: views.hpp:915
constexpr auto view(Source &&source, Indexes ... indexes)
Creates a reshaped view of a tensor or expression.
Definition: views.hpp:840
constexpr auto transpose(Source &&source, int first, int second)
Creates a transposed view of a tensor or expression by swapping two dimensions.
Definition: views.hpp:857
constexpr auto flatten(Source &&source, int start=0, int end=-1)
Flattens dimensions of a tensor into a single dimension.
Definition: views.hpp:984
constexpr auto squeeze(Source &&source)
Removes all singleton dimensions from a tensor (squeeze).
Definition: views.hpp:937
constexpr auto unsqueeze(Source &&source, Axes... axes)
Inserts singleton dimensions at the specified axes (unsqueeze).
Definition: views.hpp:962
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28