Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
comparisons.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef COMPARISONS_HPP
19#define COMPARISONS_HPP
20
65#include "concepts.hpp"
66#include "shape.hpp"
67#include "tensor.hpp"
68
70
71
/// @brief Lazy expression template for an element-wise comparison of two
/// expressions. The `Criteria` type (e.g. EQ, NE, GT, GE, LT, LE) supplies the
/// kernel; this class only validates shapes, describes the result geometry and
/// dispatches in forward(). The result always has dtype `boolean`.
83template<class Criteria, Expression First, Expression Second>
85public:
86 Criteria criteria;
89
// Members are initialized in declaration order; shape_ is declared before
// strides_, so shape_ is already set when strides_(shape_) reads it.
96 , first(first)
97 , second(second)
98 , shape_(first.shape())
99 , strides_(shape_)
100 {
// No broadcasting: operand shapes must match exactly. `first`/`second` here are
// the constructor parameters, which shadow the members of the same name.
101 if(first.shape() != second.shape())
102 throw Exception("Cannot compare tensors of different shape");
103 }
104
/// @return Always `boolean`: a comparison yields a boolean tensor regardless of operand dtypes.
105 constexpr type dtype() const {
106 return boolean;
107 }
108
/// @return Shape of the result, copied from the first operand at construction.
109 constexpr Shape const& shape() const {
110 return shape_;
111 }
112
/// @return Strides constructed from shape_ (presumably contiguous layout — confirm in strides.hpp).
113 constexpr Strides const& strides() const {
114 return strides_;
115 }
116
/// @return Always 0.
117 constexpr std::ptrdiff_t offset() const {
118 return 0;
119 }
120
/// @brief Evaluates the comparison.
/// Constructs a boolean Tensor with this expression's shape and strides (the
/// trailing 0 is presumably the offset, matching offset()), then lets
/// `criteria` fill it from the two operands.
/// NOTE(review): Criteria::forward takes `Tensor const&`; non-Tensor operand
/// expressions are presumably converted to Tensor implicitly — confirm.
121 Tensor forward() const {
122 Tensor result(boolean, shape_, strides_, 0);
123 criteria.forward(first, second, result);
124 return result;
125 }
126
127
// Result geometry, computed once in the constructor.
128private:
129 Shape shape_;
130 Strides strides_;
131
132};
133
134struct EQ {
135 void forward(Tensor const&, Tensor const&, Tensor&) const;
136};
137
138struct NE {
139 void forward(Tensor const&, Tensor const&, Tensor&) const;
140};
141
142struct GT {
143 void forward(Tensor const&, Tensor const&, Tensor&) const;
144};
145
146struct GE {
147 void forward(Tensor const&, Tensor const&, Tensor&) const;
148};
149
150struct LT {
151 void forward(Tensor const&, Tensor const&, Tensor&) const;
152};
153
154struct LE {
155 void forward(Tensor const&, Tensor const&, Tensor&) const;
156};
157
158template<Expression First, Expression Second>
159constexpr auto operator==(First&& lhs, Second&& rhs) {
160 return Comparison<EQ, First, Second>({}, lhs, rhs);
161}
162
163template<Expression First, Expression Second>
164constexpr auto operator!=(First&& lhs, Second&& rhs) {
165 return Comparison<NE, First, Second>({}, lhs, rhs);
166}
167
168template<Expression First, Expression Second>
169constexpr auto operator<(First&& lhs, Second&& rhs) {
170 return Comparison<LT, First, Second>({}, lhs, rhs);
171}
172
173template<Expression First, Expression Second>
174constexpr auto operator<=(First&& lhs, Second&& rhs) {
175 return Comparison<LE, First, Second>({}, lhs, rhs);
176}
177
178template<Expression First, Expression Second>
179constexpr auto operator>(First&& lhs, Second&& rhs) {
180 return Comparison<GT, First, Second>({}, lhs, rhs);
181}
182
183template<Expression First, Expression Second>
184constexpr auto operator>=(First&& lhs, Second&& rhs) {
185 return Comparison<GE, First, Second>({}, lhs, rhs);
186}
187
209bool allclose(Tensor const& first, Tensor const& second, double rtol = 1e-5f, double atol = 1e-8f);
210
// The leading `}` closes the expression namespace opened earlier in the file
// (its opening line is not shown here); `namespace tannic {` then reopens the
// enclosing library namespace.
211} namespace tannic {
// Re-export the comparison operators from `expression` into `tannic`, presumably
// so that unqualified uses and argument-dependent lookup on types declared in
// `tannic` (e.g. Tensor — confirm its namespace) find them.
212
213using expression::operator==;
214using expression::operator!=;
215using expression::operator<;
216using expression::operator<=;
217using expression::operator>;
218using expression::operator>=;
220
221} // namespace tannic
222
223#endif // COMPARISONS_HPP
A simple generic exception type for the Tannic Tensor Library.
Definition: exceptions.hpp:44
Represents the shape (dimensions) of a tensor-like expression.
Definition: shape.hpp:79
Represents the memory strides associated with a tensor shape.
Definition: strides.hpp:87
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Expression template for element-wise tensor comparisons.
Definition: comparisons.hpp:84
constexpr Strides const & strides() const
Definition: comparisons.hpp:113
Tensor forward() const
Definition: comparisons.hpp:121
constexpr type dtype() const
Definition: comparisons.hpp:105
Trait< First >::Reference first
Definition: comparisons.hpp:87
constexpr std::ptrdiff_t offset() const
Definition: comparisons.hpp:117
Criteria criteria
Definition: comparisons.hpp:86
constexpr Shape const & shape() const
Definition: comparisons.hpp:109
constexpr Comparison(Criteria criteria, typename Trait< First >::Reference first, typename Trait< Second >::Reference second)
Constructs a comparison expression.
Definition: comparisons.hpp:94
Trait< Second >::Reference second
Definition: comparisons.hpp:88
Definition: comparisons.hpp:69
constexpr auto operator==(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:159
constexpr auto operator>=(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:184
constexpr auto operator!=(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:164
constexpr auto operator>(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:179
constexpr auto operator<=(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:174
bool allclose(Tensor const &first, Tensor const &second, double rtol=1e-5f, double atol=1e-8f)
Determine whether two tensors are element-wise equal within a tolerance.
constexpr auto operator<(First &&lhs, Second &&rhs)
Definition: comparisons.hpp:169
Definition: buffer.hpp:41
std::decay_t< T > Reference
Definition: traits.hpp:28
Definition: comparisons.hpp:134
void forward(Tensor const &, Tensor const &, Tensor &) const
Definition: comparisons.hpp:146
void forward(Tensor const &, Tensor const &, Tensor &) const
Definition: comparisons.hpp:142
void forward(Tensor const &, Tensor const &, Tensor &) const
Definition: comparisons.hpp:154
void forward(Tensor const &, Tensor const &, Tensor &) const
Definition: comparisons.hpp:150
void forward(Tensor const &, Tensor const &, Tensor &) const
Definition: comparisons.hpp:138
void forward(Tensor const &, Tensor const &, Tensor &) const