Finds the indices of maximum values along an axis.
#include <reductions.hpp>
Output dtype is always int64. The reduced axis is removed by default (keepdim=false).
Example:
```cpp
Tensor X = { /* … */
            {1, 5, 9}};
Tensor Y = argmax(X, 0);
std::cout << Y << std::endl;
Tensor Z = argmax(X, 1, /* … */);
```
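To illustrate what the two calls compute, here is a plain C++ sketch on a 2-D array. It is not Tannic's implementation. With axis 0 the result holds one row index per column; with axis 1 it holds one column index per row. Indices are stored as int64 to match the documented output dtype. The helper name argmax2d, the nested std::vector storage, and first-occurrence tie-breaking are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of Tannic: argmax over a 2-D matrix.
// axis == 0 reduces over rows (one index per column);
// axis == 1 reduces over columns (one index per row).
// Ties resolve to the first occurrence (an assumption of this sketch).
std::vector<std::int64_t> argmax2d(const std::vector<std::vector<float>>& m, int axis) {
    assert(axis == 0 || axis == 1);
    const std::size_t rows = m.size();
    const std::size_t cols = rows ? m[0].size() : 0;
    std::vector<std::int64_t> out;
    if (axis == 0) {
        for (std::size_t j = 0; j < cols; ++j) {
            std::size_t best = 0;
            for (std::size_t i = 1; i < rows; ++i)
                if (m[i][j] > m[best][j]) best = i;  // strict > keeps the first maximum
            out.push_back(static_cast<std::int64_t>(best));
        }
    } else {
        for (std::size_t i = 0; i < rows; ++i) {
            std::size_t best = 0;
            for (std::size_t j = 1; j < cols; ++j)
                if (m[i][j] > m[i][best]) best = j;
            out.push_back(static_cast<std::int64_t>(best));
        }
    }
    return out;
}
```

With the made-up input {{3, 7, 2}, {1, 5, 9}}, argmax2d(X, 0) returns {0, 0, 1} and argmax2d(X, 1) returns {1, 2}.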
◆ forward()
void tannic::expression::Argmax::forward(Tensor const & input, Tensor & output) const
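Judging from its signature, forward() reads input and writes the argmax indices into output. The sketch below shows one common way to organize such a kernel for a tensor of any rank. The shape is viewed as outer × extent × inner, so the reduction becomes three nested loops. The helper argmax_along is hypothetical, not Tannic's code. It assumes a contiguous row-major float buffer, whereas Tannic tensors may be strided.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical kernel, not Tannic's forward(): argmax along `axis` of a
// contiguous row-major buffer with the given shape. `outer` is the product of
// the dimensions before `axis`, `inner` the product of those after it.
// The result has outer * inner elements, ordered row-major over the kept axes.
std::vector<std::int64_t> argmax_along(const std::vector<float>& data,
                                       const std::vector<std::size_t>& shape,
                                       std::size_t axis) {
    assert(axis < shape.size());
    std::size_t outer = 1, inner = 1;
    for (std::size_t d = 0; d < axis; ++d) outer *= shape[d];
    for (std::size_t d = axis + 1; d < shape.size(); ++d) inner *= shape[d];
    const std::size_t extent = shape[axis];
    assert(extent > 0 && data.size() == outer * extent * inner);

    std::vector<std::int64_t> out(outer * inner);
    for (std::size_t o = 0; o < outer; ++o) {
        for (std::size_t i = 0; i < inner; ++i) {
            // Consecutive elements along the reduced axis are `inner` apart in memory.
            const float* base = data.data() + o * extent * inner + i;
            std::size_t best = 0;
            for (std::size_t k = 1; k < extent; ++k)
                if (base[k * inner] > base[best * inner]) best = k;  // first maximum wins
            out[o * inner + i] = static_cast<std::int64_t>(best);
        }
    }
    return out;
}
```

A strided implementation would replace the implicit `inner` step with the input's stride for `axis`, but the loop structure stays the same.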
◆ reduce() [1/2]
constexpr Shape tannic::expression::Argmax::reduce(Shape const & shape) const
inline, constexpr
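This overload presumably maps the input shape to the output shape. Following the description above, the reduced axis is dropped when keepdim is false. For keepdim true, the sketch keeps that axis with extent 1. This is the usual convention, but it is an assumption here. The function names and the use of std::array in place of Tannic's Shape are hypothetical.

```cpp
#include <array>
#include <cstddef>

// Hypothetical helper (not Tannic's API): output shape when keepdim is true.
// The reduced axis stays in place with extent 1, so the rank is unchanged.
template <std::size_t N>
constexpr std::array<std::size_t, N> reduce_shape_keepdim(std::array<std::size_t, N> shape,
                                                          std::size_t axis) {
    shape[axis] = 1;  // precondition: axis < N
    return shape;
}

// Hypothetical helper (not Tannic's API): output shape when keepdim is false.
// The reduced axis is removed, so the rank drops by one.
template <std::size_t N>
constexpr std::array<std::size_t, N - 1> reduce_shape_drop(const std::array<std::size_t, N>& shape,
                                                           std::size_t axis) {
    std::array<std::size_t, N - 1> out{};
    for (std::size_t d = 0, j = 0; d < N; ++d) {
        if (d != axis) out[j++] = shape[d];  // precondition: axis < N
    }
    return out;
}

// Both helpers are usable in constant expressions (C++17 or later).
static_assert(reduce_shape_drop(std::array<std::size_t, 2>{2, 3}, 0)[0] == 3);
static_assert(reduce_shape_keepdim(std::array<std::size_t, 2>{2, 3}, 1)[1] == 1);
```

Splitting the two cases into separate functions makes the rank change visible in the return type, which a single runtime Shape class would instead track as a value.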
◆ reduce() [2/2]
constexpr type tannic::expression::Argmax::reduce(type dtype) const
inline, constexpr
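The second overload maps the input dtype to the output dtype. Per the description above, the output dtype is always int64. Here is a minimal sketch that uses a hypothetical dtype_t enum in place of Tannic's own type enumeration, whose enumerator names may differ.

```cpp
// Hypothetical dtype enumeration; Tannic's own `type` enum may differ.
enum class dtype_t { int32, int64, float32, float64 };

// Stand-in for reduce(type): argmax returns positions, not values, so the
// output dtype is int64 whatever the input dtype is.
constexpr dtype_t argmax_output_dtype(dtype_t /*input*/) {
    return dtype_t::int64;
}

static_assert(argmax_output_dtype(dtype_t::float32) == dtype_t::int64);
```

The result depends only on the operation, not on the data. The mapping can therefore be evaluated at compile time, which fits the member being constexpr.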
◆ axis
int tannic::expression::Argmax::axis
◆ keepdim
bool tannic::expression::Argmax::keepdim