16#ifndef TRANSFORMATIONS_HPP
17#define TRANSFORMATIONS_HPP
46namespace transformation {
57template<
class Operation, Expression ... Operands>
61 std::tuple<typename Trait<Operands>::Reference...>
operands;
113 Tensor result(dtype_, shape_);
114 std::apply([&](
const auto&... arguments) {
115 return operation.forward(arguments.forward()..., result);
132static constexpr auto index(type first, type second) {
133 return static_cast<int>(first) +
static_cast<int>(second) *
static_cast<int>(TYPES);
156 std::array<type, index(TYPES, TYPES)> table{};
159 table[index(int8, int8)] = int32;
160 table[index(int8, int16)] = int32;
161 table[index(int8, int32)] = int32;
162 table[index(int8, int64)] = int64;
164 table[index(int16, int8)] = int32;
165 table[index(int16, int16)] = int32;
166 table[index(int16, int32)] = int32;
167 table[index(int16, int64)] = int64;
169 table[index(int32, int8)] = int32;
170 table[index(int32, int16)] = int32;
171 table[index(int32, int32)] = int64;
172 table[index(int32, int64)] = int64;
174 table[index(int64, int8)] = int64;
175 table[index(int64, int16)] = int64;
176 table[index(int64, int32)] = int64;
177 table[index(int64, int64)] = int64;
180 table[index(int32, float32)] = float32;
181 table[index(float32, int32)] = float32;
182 table[index(int32, float64)] = float64;
183 table[index(float64, int32)] = float64;
186 table[index(float32, float32)] = float32;
187 table[index(float32, float64)] = float64;
188 table[index(float64, float32)] = float64;
189 table[index(float64, float64)] = float64;
216 constexpr static type
promote(type inner, type outer) {
218 if (dtype == unknown)
238 auto first_rank = first.
rank();
239 auto second_rank = second.
rank();
241 if (first_rank == 1 && second_rank == 1) {
242 if (first_rank != 1 | second_rank != 1)
243 throw Exception(
"dimensions must match for dot product");
247 if (first_rank == 1 && second_rank == 2) {
248 if (first[0] != second[0])
249 throw Exception(
"Matrix inner dimensions do not match");
250 if (first[0] != second[0])
251 return Shape{second[0]};
255 if (first_rank == 2 && second_rank == 1) {
256 if (first[1] != second[0])
257 throw Exception(
"Matrix inner dimensions do not match");
258 return Shape{first[0]};
261 if (first_rank < 2 | second_rank < 2)
262 throw Exception(
"Inputs must have rank >= 2");
267 Shape batches = operation::broadcast(first_batches, second_batches);
270 auto K1 = *(first.
end() - 1);
271 auto K2 = *(second.
end() - 2);
272 assert(K1 == K2 &&
"Inner dimensions must match for matmul");
275 auto M = *(first.
end() - 2);
276 auto N = *(second.
end() - 1);
279 std::vector<Shape::size_type> result(batches.
begin(), batches.
end());
282 return Shape(result);
304 static constexpr type
promote(type first, type second) {
305 return static_cast<int>(first) >
static_cast<int>(second) ? first : second;
323 if(first.
rank() != 1 | second.
rank() != 1)
324 throw Exception(
"Outer product of tensors with rank more than 1 not supported");
325 return Shape(first[0], second[0]);
385 constexpr auto promote(type first, type second)
const {
387 throw Exception(
"Cannot concatenate tensors of different dtypes");
402 assert(first.
rank() == second.
rank() &&
"Ranks must match for concatenation");
403 Shape result = first;
405 for (
int dimension = 0; dimension < first.
rank(); ++dimension) {
406 if (dimension ==
axis) {
407 result[dimension] = first[dimension] + second[dimension];
409 if (first[dimension] != second[dimension])
410 throw Exception(
"All dimensions except concat axis must match");
462template<Expression Outer, Expression Inner>
466 std::forward<Outer>(
outer),
467 std::forward<Inner>(inner)
479template<Expression First, Expression Second>
480constexpr auto outer(First&& first, Second&& second) {
483 std::forward<First>(first),
484 std::forward<Second>(second)
498template<Expression Source>
499constexpr auto repeat(Source&& source,
int repeats,
int axis = 0) {
502 std::forward<Source>(source)
517template<Expression First, Expression Second>
518constexpr auto concatenate(First&& first, Second&& second,
int axis = 0) {
520 throw Exception(
"Negative index not supported in concat");
524 std::forward<First>(first),
525 std::forward<Second>(second)
538template<Expression Source>
542 std::forward<Source>(source)
555constexpr auto reshape(Source&& source, Indexes ... indexes) {
557 repack(source), indexes...
577template<Expression Multiplicand, Expression Multiplier>
578constexpr auto matmul(Multiplicand&& multiplicand, Multiplier&& multiplier,
double scale = 1.0) {
580 std::forward<Multiplicand>(multiplicand),
581 std::forward<Multiplier>(multiplier),
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr rank_type rank() const noexcept
Returns the number of dimensions (rank).
Definition: shape.hpp:207
constexpr auto begin()
Definition: shape.hpp:215
constexpr auto end()
Definition: shape.hpp:219
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for viewing a tensor with a new shape.
Definition: views.hpp:84
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Requires a type to be an integral type (e.g., int, std::size_t).
Definition: concepts.hpp:147
constexpr Index normalize(Index index, Size bound)
Normalize a possibly-negative index into the valid range [0, bound).
Definition: indexing.hpp:87
Definition: buffer.hpp:41
constexpr auto matmul(Multiplicand &&multiplicand, Multiplier &&multiplier, double scale=1.0)
Matrix multiplication convenience function.
Definition: transformations.hpp:578
std::decay_t< T > Reference
Definition: traits.hpp:28