Tannic
A C++ Tensor Library
Loading...
Searching...
No Matches
callback.hpp
Go to the documentation of this file.
1// Copyright 2025 Eric Hermosis
2//
3// This file is part of the Tannic Tensor Library.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17
18#ifndef CALLBACK_HPP
19#define CALLBACK_HPP
20
21#include "runtime/tensor.h"
22#include "runtime/graph.h"
23#include "runtime/streams.h"
24#include "runtime/resources.h"
25#include "tensor.hpp"
26
27namespace tannic {
28
29// WARNING: This file is under active development.
30
31static inline tensor_t* get_tensor(uintptr_t id) {
32 return reinterpret_cast<node_t*>(id)->target;
33}
34
35template <class H, class D>
36class Callback {
37 H host_fn;
38 D device_fn;
39
40public:
41 Callback(H host, D device)
42 : host_fn(host)
43 , device_fn(device) {}
44
45 void operator()(Tensor const& input, Tensor& output) const {
46 output.initialize(input.environment());
47
48 if (std::holds_alternative<Host>(output.environment())) {
49 tensor_t* src = get_tensor(input.node()->id);
50 tensor_t* dst = get_tensor(output.node()->id);
51 auto status = host_fn(src, dst);
52 if (status != SUCCESS) {
53 throw std::runtime_error("Unsupported dtype");
54 }
55 }
56
57 else {
58 Device const& resource = std::get<Device>(output.environment());
59 device_t dvc{resource.id(), resource.blocking() ? SYNC : ASYNC};
60 tensor_t* src = get_tensor(input.node()->id);
61 tensor_t* dst = get_tensor(output.node()->id);
62 stream_t stream = pop_stream(&dvc);
63 auto status = device_fn(src, dst, stream);
64 put_stream(&dvc, stream);
65 if (status != SUCCESS) {
66 throw std::runtime_error("Unsupported dtype");
67 }
68 }
69 }
70
71 void operator()(Tensor const& first, Tensor const& second, Tensor& output){
72 tensor_t* src0 = get_tensor(first.node()->id);
73 tensor_t* src1 = get_tensor(second.node()->id);
74 environment_t environment;
75 auto status = resolve_environment(&src0->environment, &src1->environment, &environment);
76 if(status != SUCCESS) {
77 throw std::runtime_error("Environment issue!");
78 }
79 switch (environment.environment) {
80 case HOST: {
81 host_t resource = environment.resource.host;
82 output.initialize(Host());
83 tensor_t* dst = get_tensor(output.node()->id);
84 auto status = host_fn(src0, src1, dst);
85 if(status != SUCCESS) {
86 throw std::runtime_error("Unsupported dtype");
87 }
88 break;
89 }
90
91 case DEVICE: {
92 device_t dvc = environment.resource.device;
93 output.initialize(Device(dvc.id));
94 stream_t stream = pop_stream(&dvc);
95 tensor_t* dst = get_tensor(output.node()->id);
96 auto status = device_fn(src0, src1, dst, stream);
97 put_stream(&dvc, stream);
98 if(status != SUCCESS) {
99 throw std::runtime_error("Unsupported dtype");
100 }
101 break;
102 }
103
104 default:
105 break;
106 }
107 }
108
109};
110
111} // namespace tannic
112
113#endif
Definition: callback.hpp:36
void operator()(Tensor const &input, Tensor &output) const
Definition: callback.hpp:45
void operator()(Tensor const &first, Tensor const &second, Tensor &output)
Definition: callback.hpp:71
Callback(H host, D device)
Definition: callback.hpp:41
Device memory domain.
Definition: resources.hpp:156
bool blocking() const
Definition: resources.hpp:189
int id() const noexcept
Device identifier.
Definition: resources.hpp:185
Host memory domain.
Definition: resources.hpp:60
A multidimensional, strided tensor data structure.
Definition: tensor.hpp:99
Environment const & environment() const
Returns a reference to the environment variant used to allocate this tensor's buffer.
Definition: tensor.hpp:260
void initialize(Environment environment=Host{}) const
Allocates the memory buffer for the tensor.
Node * node() const
Definition: tensor.hpp:1122
Definition: buffer.hpp:41
uintptr_t id
Definition: graph.hpp:30