Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
slices.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef SLICES_HPP
19#define SLICES_HPP
20
#include <array>
#include <cstddef>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "types.hpp"
#include "traits.hpp"
#include "shape.hpp"
#include "strides.hpp"
#include "indexing.hpp"
62
// Forward declaration only: Slice::forward() below returns a Tensor, whose
// full definition lives in tensor.hpp.
namespace tannic { class Tensor; }
66
67namespace tannic::expression {
68
86template <Expression Source, class... Indexes>
87class Slice {
88public:
89
98 constexpr Slice(typename Trait<Source>::Reference source, std::tuple<Indexes...> indexes)
99 : dtype_(source.dtype())
100 , source_(source)
101 , indexes_(indexes)
102 {
103 if (source.rank() == 0) {
104 shape_ = Shape{};
105 strides_ = Strides{};
106 offset_ = 0;
107 }
108
109 else {
110 std::array<Shape::size_type, Shape::limit> shape{};
111 std::array<Strides::size_type, Strides::limit> strides{};
112 Shape::rank_type dimension = 0;
114 offset_ = 0;
115 auto process = [&](const auto& argument) {
116 using Argument = std::decay_t<decltype(argument)>;
117 if constexpr (std::is_same_v<Argument, indexing::Range>) {
118 auto range = normalize(argument, source.shape()[dimension]);
119 auto size = range.stop - range.start;
120 shape[rank] = size;
121 strides[rank] = source.strides()[dimension];
122 offset_ += range.start * source.strides()[dimension] * dsizeof(dtype_);
123 rank++; dimension++;
124 }
125
126 else if constexpr (std::is_integral_v<Argument>) {
127 auto index = indexing::normalize(argument, source.shape()[dimension]);
128 offset_ += index * source.strides()[dimension] * dsizeof(dtype_);
129 dimension++;
130 }
131
132 else {
133 throw Exception("Unknown index type");
134 }
135 };
136
137 std::apply([&](const auto&... arguments) {
138 (process(arguments), ...);
139 }, indexes);
140
141 while (dimension < source.rank()) {
142 shape[rank] = source.shape()[dimension];
143 strides[rank] = source.strides()[dimension];
144 rank++; dimension++;
145 }
146
147 shape_ = Shape(shape.begin(), shape.begin() + rank);
148 strides_ = Strides(strides.begin(), strides.begin() + rank);
149 }
150 }
151
157 template<Integral Index>
158 constexpr auto operator[](Index index) const {
159 return Slice<Source, Indexes..., Index>(source_, std::tuple_cat(indexes_, std::make_tuple(index)));
160 }
161
167 constexpr auto operator[](indexing::Range range) const {
168 return Slice<Source, Indexes..., indexing::Range>(source_, std::tuple_cat(indexes_, std::make_tuple(range)));
169 }
170
189 template<typename T>
190 void operator=(T value);
191
192
193
216 template<typename T>
217 bool operator==(T value) const;
218
229 constexpr auto dtype() const {
230 return dtype_;
231 }
232
245 constexpr auto rank() const {
246 return shape_.rank();
247 }
248
272 constexpr Shape const& shape() const {
273 return shape_;
274 }
275
276
294 constexpr Strides const& strides() const {
295 return strides_;
296 }
297
315 std::ptrdiff_t offset() const {
316 return offset_ + source_.offset();
317 }
318
319 std::byte* bytes() {
320 return source_.bytes() + offset_;
321 }
322
323 std::byte const* bytes() const {
324 return source_.bytes() + offset_;
325 }
326
327 Tensor forward() const;
328
329 void assign(std::byte const* value, std::ptrdiff_t offset);
330 void assign(bool const*, std::ptrdiff_t);
331 bool compare(std::byte const* value, std::ptrdiff_t offset) const;
332
333private:
334 type dtype_;
335 Shape shape_;
336 Strides strides_;
337 std::ptrdiff_t offset_;
338 typename Trait<Source>::Reference source_;
339 std::tuple<Indexes...> indexes_;
340};
341
/**
 * @brief View a value's object representation as raw bytes.
 *
 * Used to hand a host-side scalar to the byte-oriented assign()/compare()
 * hooks. std::addressof returns the real address even when T overloads
 * unary operator&. The returned pointer is only valid while @p reference lives.
 *
 * @param reference Object whose bytes are exposed.
 * @return Pointer to the first byte of @p reference.
 */
template<typename T>
inline std::byte const* tobytes(T const& reference) {
    return reinterpret_cast<std::byte const*>(std::addressof(reference));
}
346
347template <Expression Source, class... Indexes>
348template <typename T>
350 auto copy = [this](std::byte const* value, std::ptrdiff_t offset) {
351 if(rank() == 0) {
352 assign(value, offset);
353 return;
354 }
355 std::vector<std::size_t> indexes(rank(), 0);
356 bool done = false;
357
358 while (!done) {
359 std::size_t position = offset;
360 for (auto dimension = 0; dimension < rank(); ++dimension) {
361 position += indexes[dimension] * strides_[dimension] * dsizeof(dtype_);
362 }
363
364 assign(value, position);
365 done = true;
366 for (int dimension = rank() - 1; dimension >= 0; --dimension) {
367 if (++indexes[dimension] < shape_[dimension]) {
368 done = false;
369 break;
370 }
371 indexes[dimension] = 0;
372 }
373 }
374 };
375
376 switch (dtype_) {
377 case int8: { int8_t casted = value; copy(tobytes(casted), offset()); break; }
378 case int16: { int16_t casted = value; copy(tobytes(casted), offset()); break; }
379 case int32: { int32_t casted = value; copy(tobytes(casted), offset()); break; }
380 case int64: { int64_t casted = value; copy(tobytes(casted), offset()); break; }
381 case float32: { float casted = value; copy(tobytes(casted), offset()); break; }
382 case float64: { double casted = value; copy(tobytes(casted), offset()); break; }
383 default: throw Exception("Unsupported dtype for assignment");
384 }
385}
386
387template <Expression Source, class... Indexes>
388template <typename T>
390 if (rank() != 0)
391 throw Exception("Cannot compare an scalar to a non scalar slice");
392
393 switch (dtype_) {
394 case int8: { int8_t casted = value; return compare(tobytes(casted), offset()); }
395 case int16: { int16_t casted = value; return compare(tobytes(casted), offset()); }
396 case int32: { int32_t casted = value; return compare(tobytes(casted), offset()); }
397 case int64: { int64_t casted = value; return compare(tobytes(casted), offset()); }
398 case float32: { float casted = value; return compare(tobytes(casted), offset()); }
399 case float64: { double casted = value; return compare(tobytes(casted), offset()); }
400 default: throw Exception("Unsupported dtype for comparison");
401 }
402}
403
404} // namespace tannic::expression
405
406#endif // SLICES_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr auto begin()
Definition: shape.hpp:215
uint8_t rank_type
Type used for rank (number of dimensions).
Definition: shape.hpp:84
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
constexpr auto begin()
Definition: strides.hpp:200
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template representing a tensor slice or subview.
Definition: slices.hpp:87
constexpr Slice(typename Trait< Source >::Reference source, std::tuple< Indexes... > indexes)
Create a slice from a source expression and an index tuple.
Definition: slices.hpp:98
void operator=(T value)
Assigns a scalar value to all elements in the slice.
Definition: slices.hpp:349
std::ptrdiff_t offset() const
Returns the byte offset from the source tensor's data pointer.
Definition: slices.hpp:315
constexpr Shape const & shape() const
Returns the shape (size in each dimension) of this slice.
Definition: slices.hpp:272
constexpr auto operator[](indexing::Range range) const
Index into the slice with a range.
Definition: slices.hpp:167
std::byte * bytes()
Definition: slices.hpp:319
Tensor forward() const
Definition: tensor.hpp:1207
constexpr auto rank() const
Returns the number of dimensions in this slice.
Definition: slices.hpp:245
bool operator==(T value) const
Compares a scalar value to the element in a rank-0 slice.
Definition: slices.hpp:389
std::byte const * bytes() const
Definition: slices.hpp:323
constexpr auto operator[](Index index) const
Index into the slice with an integer.
Definition: slices.hpp:158
bool compare(std::byte const *value, std::ptrdiff_t offset) const
Definition: tensor.hpp:1238
constexpr Strides const & strides() const
Returns the memory strides for this slice.
Definition: slices.hpp:294
void assign(std::byte const *value, std::ptrdiff_t offset)
Definition: tensor.hpp:1213
constexpr auto dtype() const
Returns the runtime data type of elements in this slice.
Definition: slices.hpp:229
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Definition: comparisons.hpp:69
std::byte const * tobytes(T const &reference)
Definition: slices.hpp:343
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
constexpr std::size_t dsizeof(type type)
Returns the size in bytes of a given tensor data type.
Definition: types.hpp:93
std::decay_t< T > Reference
Definition: traits.hpp:28
Represents a half-open interval [start, stop) for slicing.
Definition: indexing.hpp:56
int start
Definition: indexing.hpp:57
int stop
Definition: indexing.hpp:58