Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
strides.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef STRIDES_HPP
19#define STRIDES_HPP
20
44#include <array>
45#include <cstddef>
46#include <stdexcept>
47#include <ostream>
48
49#include "concepts.hpp"
50#include "shape.hpp"
51#include "indexing.hpp"
52#include "exceptions.hpp"
53
54namespace tannic {
55
87class Strides {
88public:
89 static constexpr uint8_t limit = 8;
90
91public:
92 using rank_type = uint8_t;
93 using size_type = int64_t;
94
95
97 constexpr Strides() noexcept = default;
98
113 template<Integral... Sizes>
114 constexpr Strides(Sizes... sizes)
115 : sizes_{static_cast<size_type>(sizes)...}
116 , rank_(sizeof...(sizes)) {
117 if (rank_ > limit)
118 throw Exception("Strides rank limit exceeded");
119 }
120
137 template<Iterator Iterator>
139 size_type dimension = 0;
140 for (auto iterator = begin; iterator != end; ++iterator) {
141 assert(dimension < limit && "Strides rank limit exceeded");
142 sizes_[dimension++] = static_cast<size_type>(*iterator);
143 }
144 rank_ = dimension;
145 if (rank_ > limit)
146 throw Exception("Strides rank limit exceeded");
147 }
148
161 constexpr Strides(const Shape& shape) {
162 rank_ = shape.rank();
163 if (rank_ == 0) return;
164
165 sizes_[rank_ - 1] = 1;
166 for (int size = rank_ - 2; size >= 0; --size) {
167 sizes_[size] = sizes_[size + 1] * shape[size + 1];
168 }
169 if (rank_ > limit)
170 throw Exception("Strides rank limit exceeded");
171 }
172
173public:
178 constexpr size_type* address() noexcept {
179 return sizes_.data();
180 }
181
186 constexpr size_type const* address() const noexcept {
187 return sizes_.data();
188 }
189
194 constexpr auto rank() const noexcept {
195 return rank_;
196 }
197
200 constexpr auto begin() {
201 return sizes_.begin();
202 }
203
204 constexpr auto end() {
205 return sizes_.begin() + rank_;
206 }
207
208 constexpr auto begin() const {
209 return sizes_.begin();
210 }
211
212 constexpr auto end() const {
213 return sizes_.begin() + rank_;
214 }
215
216 constexpr auto cbegin() const {
217 return sizes_.cbegin();
218 }
219
220 constexpr auto cend() const {
221 return sizes_.cbegin() + rank_;
222 }
224
229 constexpr auto front() const {
230 return sizes_.front();
231 }
232
237 constexpr auto back() const {
238 return sizes_[rank_];
239 }
240
247 template<Integral Index>
248 constexpr auto const& operator[](Index index) const {
249 return sizes_[indexing::normalize(index, rank())];
250 }
251
258 template<Integral Index>
259 constexpr auto& operator[](Index index) {
260 return sizes_[indexing::normalize(index, rank())];
261 }
262
267 constexpr void expand(size_type size) {
268 if (rank_ + 1 > limit)
269 throw Exception("Strides rank limit exceeded");
270 sizes_[rank_] = size;
271 rank_ += 1;
272 }
273
274private:
275 rank_type rank_{0};
276 std::array<size_type, limit> sizes_{};
277};
278
283constexpr bool operator==(Strides const& first, Strides const& second) {
284 if (first.rank() != second.rank()) return false;
285 for (Strides::rank_type dimension = 0; dimension < first.rank(); ++dimension) {
286 if (first[dimension] != second[dimension]) return false;
287 }
288 return true;
289}
290
291inline std::ostream& operator<<(std::ostream& os, Strides const& strides) {
292 os << "Strides(";
293 for (Strides::rank_type dimension = 0; dimension < strides.rank(); ++dimension) {
294 os << static_cast<unsigned int>(strides[dimension]);
295 if (dimension + 1 < strides.rank()) {
296 os << ", ";
297 }
298 }
299 os << ")";
300 return os;
301}
302
303} // namespace tannic
304
305#endif // STRIDES_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
static constexpr uint8_t limit
Definition: strides.hpp:89
constexpr auto cend() const
Definition: strides.hpp:220
constexpr auto const & operator[](Index index) const
Accesses a stride value by index (const).
Definition: strides.hpp:248
constexpr auto begin()
Definition: strides.hpp:200
constexpr auto end() const
Definition: strides.hpp:212
constexpr size_type * address() noexcept
Returns a pointer to the underlying data (non-const).
Definition: strides.hpp:178
constexpr auto rank() const noexcept
Returns the number of dimensions (rank).
Definition: strides.hpp:194
constexpr auto cbegin() const
Definition: strides.hpp:216
constexpr size_type const * address() const noexcept
Returns a pointer to the underlying data (const).
Definition: strides.hpp:186
uint8_t rank_type
Type used for rank (number of dimensions).
Definition: strides.hpp:92
constexpr Strides(const Shape &shape)
Constructs strides from a shape assuming row-major layout.
Definition: strides.hpp:161
constexpr auto front() const
Returns the first stride value.
Definition: strides.hpp:229
constexpr auto & operator[](Index index)
Accesses a stride value by index (non-const).
Definition: strides.hpp:259
constexpr Strides(Iterator begin, Iterator end)
Constructs strides from a pair of iterators.
Definition: strides.hpp:138
constexpr void expand(size_type size)
Appends a new trailing dimension with the given stride value, increasing the rank by one.
Definition: strides.hpp:267
constexpr auto begin() const
Definition: strides.hpp:208
constexpr auto back() const
Returns the last stride value.
Definition: strides.hpp:237
int64_t size_type
Type used for size and shape dimensions.
Definition: strides.hpp:93
constexpr auto end()
Definition: strides.hpp:204
constexpr Strides() noexcept=default
Default constructor (rank 0).
Requires a type to be an integral type (e.g., int, std::size_t).
Definition: concepts.hpp:147
Requires a type to satisfy the C++20 std::input_iterator concept.
Definition: concepts.hpp:140
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
std::ostream & operator<<(std::ostream &os, Shape const &shape)
Definition: shape.hpp:313