Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
complex.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef COMPLEX_HPP
19#define COMPLEX_HPP
20
37#include "concepts.hpp"
38#include "types.hpp"
39#include "shape.hpp"
40#include "strides.hpp"
41#include "traits.hpp"
42#include <cassert>
43
44namespace tannic {
45
46class Tensor;
47
48} namespace tannic::expression {
49
57struct Cartesian {
58 static void forward(Tensor const&, Tensor const&, Tensor&);
59};
60
68struct Polar {
69 static void forward(Tensor const&, Tensor const&, Tensor&);
70};
71
72
85template<class Coordinates, Expression ... Sources>
87
88// Single-source specialization (interleaved real/imaginary data)
89template<class Coordinates, Expression Source>
90class Complexification<Coordinates, Source> {
91public:
93
99 : source(source)
100 {
101 switch (source.dtype()) {
102 case float32: dtype_ = complex64; break;
103 case float64: dtype_ = complex128; break;
104 default:
105 throw Exception("Complex view error: source tensor dtype must be float32 or float64");
106 }
107
108 if (source.strides()[-1] == 1 && source.strides()[-2] == 2) {
109 assert(source.shape().back() == 2 &&
110 "Complex view error: last dimension must be size 2 (real + imag).");
111 shape_ = Shape(source.shape().begin(), source.shape().end() - 1);
112 strides_ = Strides(source.strides().begin(), source.strides().end() - 1);
113 strides_[-1] = 1;
114 } else {
115 throw Exception(
116 "Complex view error: source tensor is not contiguous in last two dimensions. "
117 "Cannot create complex view safely."
118 );
119 }
120 }
121
131 constexpr type dtype() const {
132 return dtype_;
133 }
134
151 constexpr Shape const& shape() const {
152 return shape_;
153 }
154
155
165 constexpr Strides const& strides() const {
166 return strides_;
167 }
168
169
175 std::ptrdiff_t offset() const {
176 return source.offset();
177 }
178
179 Tensor forward() const;
180
181private:
182 type dtype_;
183 Shape shape_;
184 Strides strides_;
185};
186
187// Dual-source specialization (separate real/imaginary tensors)
188template<class Coordinates, Expression Real, Expression Imaginary>
189class Complexification<Coordinates, Real, Imaginary> {
190public:
193
198 constexpr Complexification(typename Trait<Real>::Reference real, typename Trait<Imaginary>::Reference imaginary)
199 : real(real)
200 , imaginary(imaginary)
201 {
202 if (real.shape() != imaginary.shape() | real.strides() != imaginary.strides())
203 throw Exception("Complexification error: real and imaginary part layouts must match");
204
205 if (real.dtype() == float64 || imaginary.dtype() == float64) {
206 dtype_ = complex128;
207 }
208
209 else {
210 dtype_ = complex64;
211 }
212 }
213
218 constexpr type dtype() const {
219 return dtype_;
220 }
221
230 constexpr Shape const& shape() const {
231 return real.shape();
232 }
233
242 constexpr Strides const& strides() const {
243 return real.strides();
244 }
245
250 std::ptrdiff_t offset() const {
251 return 0;
252 }
253
254 Tensor forward() const;
255
256private:
257 type dtype_;
258};
259
260
289template<Expression Source>
291public:
293
299 : source(source)
300 {
301 switch (source.dtype()) {
302 case complex64: dtype_ = float32; break;
303 case complex128: dtype_ = float64; break;
304 default:
305 assert(false &&
306 "Real view error: source tensor dtype must be complex64 or complex128");
307 }
308
309 if (source.strides()[-1] == 1) {
310 shape_ = Shape(source.shape().begin(), source.shape().end());
311 shape_.expand(2);
312
313 strides_ = Strides(source.strides().begin(), source.strides().end());
314 strides_.expand(1);
315 strides_[-2] = 2;
316 } else {
317 throw Exception("Real view error: source tensor is not in interleaved real/imag format");
318 }
319
320 }
321
331 constexpr type dtype() const {
332 return dtype_;
333 }
334
335
351 constexpr Shape const& shape() const {
352 return shape_;
353 }
354
364 constexpr Strides const& strides() const {
365 return strides_;
366 }
367
368 std::ptrdiff_t offset() const {
369 return source.offset();
370 }
371
372 Tensor forward() const;
373
374private:
375 type dtype_;
376 Shape shape_;
377 Strides strides_;
378};
379
380
401template<Expression Real>
402constexpr auto complexify(Real&& real) {
403 return Complexification<Cartesian, Real>{std::forward<Real>(real)};
404}
405
425template<Expression Real, Expression Imaginary>
426constexpr auto complex(Real&& real, Imaginary&& imaginary) {
428 std::forward<Real>(real),
429 std::forward<Imaginary>(imaginary)
430 };
431}
432
452template<Expression Magnitude, Expression Angle>
453constexpr auto polar(Magnitude&& rho, Angle&& theta) {
455 std::forward<Magnitude>(rho),
456 std::forward<Angle>(theta)
457 };
458}
459
485template<Expression Complex>
486constexpr auto realify(Complex&& complex) {
487 return Realification<Complex>{std::forward<Complex>(complex)};
488}
489
490} namespace tannic {
491
496
497} // namespace tannic
498
499#endif // COMPLEX_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
constexpr void expand(size_type size)
Expands the shape's last dimension with a given size.
Definition: shape.hpp:286
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
constexpr void expand(size_type size)
Expands the strides' last dimension with a given size.
Definition: strides.hpp:267
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Trait< Imaginary >::Reference imaginary
Definition: complex.hpp:192
std::ptrdiff_t offset() const
Returns the offset of the complex expression.
Definition: complex.hpp:250
constexpr Shape const & shape() const
Returns the shape of the complex tensor.
Definition: complex.hpp:230
constexpr Strides const & strides() const
Returns the strides of the complex tensor.
Definition: complex.hpp:242
Trait< Real >::Reference real
Definition: complex.hpp:191
constexpr Complexification(typename Trait< Real >::Reference real, typename Trait< Imaginary >::Reference imaginary)
Constructs the complex expression from two source expressions (real/imaginary or magnitude/angle).
Definition: complex.hpp:198
constexpr type dtype() const
Returns the complex dtype of the combined tensor.
Definition: complex.hpp:218
constexpr type dtype() const
Returns the complex dtype of the view.
Definition: complex.hpp:131
constexpr Strides const & strides() const
Returns the strides of the complex view.
Definition: complex.hpp:165
Trait< Source >::Reference source
Definition: complex.hpp:92
std::ptrdiff_t offset() const
Returns the offset of the view.
Definition: complex.hpp:175
constexpr Complexification(Trait< Source >::Reference source)
Constructs complex view from interleaved data.
Definition: complex.hpp:98
constexpr Shape const & shape() const
Returns the shape of the complex view.
Definition: complex.hpp:151
Creates a complex tensor view from real components.
Definition: complex.hpp:86
Creates a real-valued view of complex tensor data.
Definition: complex.hpp:290
constexpr Shape const & shape() const
Returns the shape of the real view.
Definition: complex.hpp:351
constexpr type dtype() const
Returns the real dtype of the view.
Definition: complex.hpp:331
constexpr Realification(Trait< Source >::Reference source)
Definition: complex.hpp:298
Trait< Source >::Reference source
Definition: complex.hpp:292
std::ptrdiff_t offset() const
Definition: complex.hpp:368
constexpr Strides const & strides() const
Returns the strides of the real view.
Definition: complex.hpp:364
Tensor forward() const
Definition: tensor.hpp:1264
Defines the core protocol for all expression-like types in the Tannic Tensor Library.
Definition: concepts.hpp:86
Definition: comparisons.hpp:69
constexpr auto realify(Complex &&complex)
Creates a real-valued view of complex tensor data.
Definition: complex.hpp:486
constexpr auto complex(Real &&real, Imaginary &&imaginary)
Creates complex tensor from separate real and imaginary tensors
Definition: complex.hpp:426
constexpr auto polar(Magnitude &&rho, Angle &&theta)
Creates complex tensor from polar coordinates (magnitude/angle)
Definition: complex.hpp:453
constexpr auto complexify(Real &&real)
Creates a complex tensor view from interleaved real/imaginary data.
Definition: complex.hpp:402
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28
Tag type for Cartesian (real/imaginary) complex number representation.
Definition: complex.hpp:57
static void forward(Tensor const &, Tensor const &, Tensor &)
Tag type for Polar (magnitude/angle) complex number representation.
Definition: complex.hpp:68
static void forward(Tensor const &, Tensor const &, Tensor &)