83template<Expression Source>
101 std::array<
long long,
sizeof...(Sizes)> requested{
static_cast<long long>(sizes)... };
103 std::size_t source_elements = std::accumulate(
104 source.shape().begin(), source.shape().end(), 1ULL, std::multiplies<>{}
109 std::size_t known_product = 1;
110 for (std::size_t i = 0; i < requested.size(); ++i) {
111 if (requested[i] == -1) {
112 if (infer_axis != -1)
113 throw Exception(
"Only one dimension can be inferred (-1) in view");
114 infer_axis =
static_cast<int>(i);
115 }
else if (requested[i] < 0) {
116 throw Exception(
"Invalid negative dimension in view");
118 known_product *=
static_cast<std::size_t
>(requested[i]);
122 if (infer_axis != -1) {
123 if (source_elements % known_product != 0)
124 throw Exception(
"Cannot infer dimension: source elements not divisible");
125 requested[infer_axis] =
static_cast<long long>(source_elements / known_product);
128 for (
auto v : requested) {
129 shape_.
expand(
static_cast<std::size_t
>(v));
132 std::size_t new_elements = std::accumulate(shape_.
begin(), shape_.
end(), 1ULL, std::multiplies<>{});
133 if (new_elements != source_elements)
134 throw Exception(
"Shape mismatch: view must preserve total number of elements");
148 return source_.dtype();
181 return source_.offset();
210template<Expression Source>
221 : shape_(source.
shape())
224 , dimensions_(dimensions) {
225 auto rank = source.shape().rank();
237 return source_.dtype();
269 return source_.offset();
278 std::pair<int, int> dimensions_;
319 if (
sizeof...(Indexes) != source_.shape().rank()) {
320 throw Exception(
"Permutation rank must equal tensor rank");
323 std::apply([&](
auto... indexes) {
326 shape_.
expand(source_.shape()[dimension]);
327 strides_.
expand(source_.strides()[dimension]);
338 return source_.dtype();
368 return source_.offset();
376 typename Trait<Source>::Reference source_;
397template<Expression Source>
412 std::array<
long long,
sizeof...(Sizes)> requested{
static_cast<long long>(sizes)... };
414 if (requested.size() < source.shape().rank())
415 throw Exception(
"Expansion target rank must be >= source rank");
417 std::size_t
offset = requested.size() - source.shape().rank();
418 for (std::size_t dimension = 0; dimension < requested.size(); ++dimension) {
419 long long index = requested[dimension];
424 throw Exception(
"Cannot use -1 for new leading dimensions");
426 target = source.shape()[dimension -
offset];
427 }
else if (index <= 0) {
428 throw Exception(
"Expansion size must be positive or -1");
430 target =
static_cast<std::size_t
>(index);
439 if (source.shape()[dimension -
offset] == 1 && target > 1) {
442 }
else if (source.shape()[dimension -
offset] == target) {
446 throw Exception(
"Expansion only allows -1 (keep) or broadcasting singleton dims");
460 return source_.dtype();
496 return source_.offset();
522template<Expression Source>
535 for (
auto dimension = 0; dimension < source.shape().rank(); ++dimension) {
536 if (source.shape()[dimension] != 1) {
537 shape_.
expand(source.shape()[dimension]);
538 strides_.
expand(source.strides()[dimension]);
550 return source_.dtype();
589 return source_.offset();
597 typename Trait<Source>::Reference source_;
618template<Expression Source>
637 auto rank = source.shape().rank();
638 std::vector<std::size_t> normalized{
static_cast<std::size_t
>(
indexing::normalize(axes, rank +
sizeof...(axes)))... };
639 std::sort(normalized.begin(), normalized.end());
641 size_t dimensions = rank + normalized.size();
645 for (
auto dimension = 0; dimension < dimensions; ++dimension) {
646 if (axis < normalized.size() && dimension == normalized[axis]) {
648 strides_.
expand( (index < source.strides().rank()) ? source.strides()[index] : 1 );
651 shape_.
expand(source.shape()[index]);
652 strides_.
expand(source.strides()[index]);
665 return source_.dtype();
702 return source_.offset();
710 typename Trait<Source>::Reference source_;
731template<Expression Source>
736 auto rank = source.shape().rank();
742 throw Exception(
"Flatten requires start_dim <= end_dim");
745 for (
int dimension = 0; dimension < start; ++dimension) {
746 shape_.
expand(source.shape()[dimension]);
747 strides_.
expand(source.strides()[dimension]);
750 std::size_t flattened = 1;
751 for (
int dimension = start; dimension <= end; ++dimension) {
752 flattened *= source.shape()[dimension];
755 strides_.
expand(source.strides()[end]);
757 for (
int dimension = end + 1; dimension < rank; ++dimension) {
758 shape_.
expand(source.shape()[dimension]);
759 strides_.
expand(source.strides()[dimension]);
769 constexpr type
dtype()
const {
return source_.dtype(); }
812 std::ptrdiff_t
offset()
const {
return source_.offset(); }
819 typename Trait<Source>::Reference source_;
839template<Expression Source, Integral ... Indexes>
840constexpr auto view(Source&& source, Indexes ... indexes) {
842 std::forward<Source>(source), indexes...
856template<Expression Source>
857constexpr auto transpose(Source&& source,
int first,
int second) {
859 std::forward<Source>(source),
860 std::make_pair(first, second)
880constexpr auto permute(Source&& source, Indexes... indexes) {
882 std::forward<Source>(source),
883 std::make_tuple(indexes...)
915constexpr auto expand(Source&& source, Sizes... sizes) {
936template<Expression Source>
962constexpr auto unsqueeze(Source&& source, Axes... axes) {
983template<Expression Source>
984constexpr auto flatten(Source&& source,
int start = 0,
int end = -1) {
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr auto begin()
Definition: shape.hpp:215
constexpr void expand(size_type size)
Appends a new trailing dimension of the given size to the shape.
Definition: shape.hpp:286
constexpr auto end()
Definition: shape.hpp:219
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
constexpr void expand(size_type size)
Appends a new trailing stride entry with the given value to the strides.
Definition: strides.hpp:267
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for expanding (broadcasting) singleton dimensions of a tensor.
Definition: views.hpp:398
Tensor forward() const
Definition: tensor.hpp:1201
constexpr Expansion(typename Trait< Source >::Reference source, Sizes... sizes)
Construct an expansion view.
Definition: views.hpp:410
constexpr Shape const & shape() const
Definition: views.hpp:472
constexpr Strides const & strides() const
Definition: views.hpp:485
std::ptrdiff_t offset() const
Definition: views.hpp:495
constexpr type dtype() const
Definition: views.hpp:459
Expression template for flattening a contiguous range of dimensions.
Definition: views.hpp:732
constexpr Strides const & strides() const
Definition: views.hpp:802
Tensor forward() const
Definition: tensor.hpp:1183
constexpr Shape const & shape() const
Definition: views.hpp:784
std::ptrdiff_t offset() const
Definition: views.hpp:812
constexpr type dtype() const
Definition: views.hpp:769
constexpr Flatten(typename Trait< Source >::Reference source, int start=0, int end=-1)
Definition: views.hpp:734
Expression template for reordering tensor dimensions according to a specified permutation.
Definition: views.hpp:303
std::ptrdiff_t offset() const
Definition: views.hpp:367
Tensor forward() const
Definition: tensor.hpp:1189
constexpr Strides const & strides() const
Definition: views.hpp:357
constexpr Permutation(typename Trait< Source >::Reference source, std::tuple< Indexes... > indexes)
Constructs a permuted view of the source tensor.
Definition: views.hpp:316
constexpr Shape const & shape() const
Definition: views.hpp:347
constexpr type dtype() const
Definition: views.hpp:337
Expression template for removing singleton dimensions from a tensor.
Definition: views.hpp:523
std::ptrdiff_t offset() const
Definition: views.hpp:588
constexpr Squeeze(typename Trait< Source >::Reference source)
Construct a squeezed view of the source tensor.
Definition: views.hpp:533
constexpr Shape const & shape() const
Definition: views.hpp:563
constexpr type dtype() const
Definition: views.hpp:549
Tensor forward() const
Definition: tensor.hpp:1171
constexpr Strides const & strides() const
Definition: views.hpp:578
Expression template for transposing two dimensions of a tensor.
Definition: views.hpp:211
Tensor forward() const
Definition: tensor.hpp:1195
constexpr Shape const & shape() const
Definition: views.hpp:246
std::ptrdiff_t offset() const
Definition: views.hpp:268
constexpr Transpose(typename Trait< Source >::Reference source, std::pair< int, int > dimensions)
Construct a transposed view of the source tensor.
Definition: views.hpp:220
constexpr Strides const & strides() const
Definition: views.hpp:257
constexpr type dtype() const
Definition: views.hpp:236
Expression template for inserting singleton dimensions into a tensor.
Definition: views.hpp:619
constexpr Unsqueeze(typename Trait< Source >::Reference source, Axes... axes)
Construct an unsqueezed view of the source tensor.
Definition: views.hpp:635
constexpr Strides const & strides() const
Definition: views.hpp:691
constexpr type dtype() const
Definition: views.hpp:664
Tensor forward() const
Definition: tensor.hpp:1177
constexpr Shape const & shape() const
Definition: views.hpp:679
std::ptrdiff_t offset() const
Definition: views.hpp:701
Expression template for viewing a tensor with a new shape.
Definition: views.hpp:84
constexpr Shape const & shape() const
Definition: views.hpp:157
Tensor forward() const
Definition: tensor.hpp:1165
constexpr type dtype() const
Definition: views.hpp:147
std::ptrdiff_t offset() const
Definition: views.hpp:180
constexpr Strides const & strides() const
Definition: views.hpp:167
constexpr View(typename Trait< Source >::Reference source, Sizes... sizes)
Construct a reshaped view of the source tensor.
Definition: views.hpp:98
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Requires a type to be an integral type (e.g., int, std::size_t).
Definition: concepts.hpp:147
constexpr auto permute(Source &&source, Indexes... indexes)
Creates a permuted view of a tensor or expression.
Definition: views.hpp:880
constexpr auto expand(Source &&source, Sizes... sizes)
Creates an expanded view of a tensor, broadcasting singleton dimensions.
Definition: views.hpp:915
constexpr auto view(Source &&source, Indexes ... indexes)
Creates a reshaped view of a tensor or expression.
Definition: views.hpp:840
constexpr auto transpose(Source &&source, int first, int second)
Creates a transposed view of a tensor or expression by swapping two dimensions.
Definition: views.hpp:857
constexpr auto flatten(Source &&source, int start=0, int end=-1)
Flattens dimensions of a tensor into a single dimension.
Definition: views.hpp:984
constexpr auto squeeze(Source &&source)
Removes all singleton dimensions from a tensor (squeeze).
Definition: views.hpp:937
constexpr auto unsqueeze(Source &&source, Axes... axes)
Inserts singleton dimensions at the specified axes (unsqueeze).
Definition: views.hpp:962
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28