21#include "runtime/tensor.h"
22#include "runtime/graph.h"
23#include "runtime/streams.h"
24#include "runtime/resources.h"
31static inline tensor_t* get_tensor(uintptr_t
id) {
32 return reinterpret_cast<node_t*
>(id)->target;
35template <
class H,
class D>
43 , device_fn(device) {}
48 if (std::holds_alternative<Host>(output.
environment())) {
49 tensor_t* src = get_tensor(input.
node()->
id);
50 tensor_t* dst = get_tensor(output.
node()->
id);
51 auto status = host_fn(src, dst);
52 if (status != SUCCESS) {
53 throw std::runtime_error(
"Unsupported dtype");
59 device_t dvc{resource.
id(), resource.
blocking() ? SYNC : ASYNC};
60 tensor_t* src = get_tensor(input.
node()->
id);
61 tensor_t* dst = get_tensor(output.
node()->
id);
62 stream_t stream = pop_stream(&dvc);
63 auto status = device_fn(src, dst, stream);
64 put_stream(&dvc, stream);
65 if (status != SUCCESS) {
66 throw std::runtime_error(
"Unsupported dtype");
72 tensor_t* src0 = get_tensor(first.
node()->
id);
73 tensor_t* src1 = get_tensor(second.
node()->
id);
74 environment_t environment;
75 auto status = resolve_environment(&src0->environment, &src1->environment, &environment);
76 if(status != SUCCESS) {
77 throw std::runtime_error(
"Environment issue!");
79 switch (environment.environment) {
81 host_t resource = environment.resource.host;
83 tensor_t* dst = get_tensor(output.
node()->
id);
84 auto status = host_fn(src0, src1, dst);
85 if(status != SUCCESS) {
86 throw std::runtime_error(
"Unsupported dtype");
92 device_t dvc = environment.resource.device;
94 stream_t stream = pop_stream(&dvc);
95 tensor_t* dst = get_tensor(output.
node()->
id);
96 auto status = device_fn(src0, src1, dst, stream);
97 put_stream(&dvc, stream);
98 if(status != SUCCESS) {
99 throw std::runtime_error(
"Unsupported dtype");
Definition: callback.hpp:36
void operator()(Tensor const &input, Tensor &output) const
Definition: callback.hpp:45
void operator()(Tensor const &first, Tensor const &second, Tensor &output)
Definition: callback.hpp:71
Callback(H host, D device)
Definition: callback.hpp:41
Device memory domain.
Definition: resources.hpp:156
bool blocking() const
Definition: resources.hpp:189
int id() const noexcept
Device identifier.
Definition: resources.hpp:185
Host memory domain.
Definition: resources.hpp:60
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Environment const & environment() const
Returns a reference to the environment variant used to allocate this tensor's buffer.
Definition: tensor.hpp:260
void initialize(Environment environment=Host{}) const
Allocates the memory buffer for the tensor.
Node * node() const
Definition: tensor.hpp:1122
Definition: buffer.hpp:41
uintptr_t id
Definition: graph.hpp:30