Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated files
build/
build-*/
generated/

# Prerequisites
Expand Down
1 change: 1 addition & 0 deletions include/infini/rt.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef INFINI_RT_PUBLIC_H_
#define INFINI_RT_PUBLIC_H_

#include <infini/rt/c_api.h>
#include <infini/rt/generated.h>

#endif
78 changes: 78 additions & 0 deletions include/infini/rt/c_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#ifndef INFINI_RT_C_API_H_
#define INFINI_RT_C_API_H_

#if defined(_WIN32)
#define INFINI_RT_EXPORT __declspec(dllexport)
#elif defined(__GNUC__) && \
((__GNUC__ >= 4) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 3))
#define INFINI_RT_EXPORT __attribute__((visibility("default")))
#else
#define INFINI_RT_EXPORT
#endif

#ifdef __cplusplus
#define INFINI_RT_EXTERN_C extern "C"
#else
#define INFINI_RT_EXTERN_C
#endif

typedef enum {
INFINI_RT_STATUS_SUCCESS = 0,
INFINI_RT_STATUS_INVALID_ARGUMENT = 1,
INFINI_RT_STATUS_UNSUPPORTED_DEVICE = 2,
INFINI_RT_STATUS_RUNTIME_ERROR = 3,
} infiniRtStatus_t;

typedef enum {
INFINI_RT_DEVICE_CPU = 0,
INFINI_RT_DEVICE_NVIDIA = 1,
INFINI_RT_DEVICE_CAMBRICON = 2,
INFINI_RT_DEVICE_ASCEND = 3,
INFINI_RT_DEVICE_METAX = 4,
INFINI_RT_DEVICE_MOORE = 5,
INFINI_RT_DEVICE_ILUVATAR = 6,
INFINI_RT_DEVICE_KUNLUN = 7,
INFINI_RT_DEVICE_HYGON = 8,
INFINI_RT_DEVICE_QY = 9,
} infiniRtDeviceType_t;

typedef struct {
infiniRtDeviceType_t type;
int index;
} infiniRtDevice_t;

typedef enum {
INFINI_RT_STREAM_CAPTURE_MODE_GLOBAL = 0,
INFINI_RT_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1,
INFINI_RT_STREAM_CAPTURE_MODE_RELAXED = 2,
} infiniRtStreamCaptureMode_t;

typedef void* infiniRtStream_t;
typedef void* infiniRtGraph_t;
typedef void* infiniRtGraphExec_t;

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtStreamWrap(
infiniRtDevice_t device, void* native_stream, infiniRtStream_t* stream);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t
infiniRtStreamDestroy(infiniRtStream_t stream);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtStreamBeginCapture(
infiniRtStream_t stream, infiniRtStreamCaptureMode_t mode);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t
infiniRtStreamEndCapture(infiniRtStream_t stream, infiniRtGraph_t* graph);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t
infiniRtGraphDestroy(infiniRtGraph_t graph);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtGraphInstantiate(
infiniRtGraphExec_t* graph_exec, infiniRtGraph_t graph);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t
infiniRtGraphExecDestroy(infiniRtGraphExec_t graph_exec);

INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t
infiniRtGraphLaunch(infiniRtGraphExec_t graph_exec, infiniRtStream_t stream);

#endif
131 changes: 129 additions & 2 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _rewrite_detail_include(match):


_DETAIL_INCLUDE_PATTERN = re.compile(
r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|hash\.h|runtime\.h|tensor_view\.h)"'
r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|graph\.h|hash\.h|runtime\.h|tensor_view\.h)"'
)


Expand Down Expand Up @@ -133,6 +133,7 @@ def _write_detail_headers(include_root, source_root, devices):
"data_type.h",
"device.h",
"dispatcher.h",
"graph.h",
"hash.h",
"runtime.h",
"tensor_view.h",
Expand All @@ -158,6 +159,7 @@ def _write_generated_header(include_root, devices):
includes = [
f"#include {_detail_include('data_type.h')}",
f"#include {_detail_include('device.h')}",
f"#include {_detail_include('graph.h')}",
f"#include {_detail_include('hash.h')}",
f"#include {_detail_include('runtime.h')}",
f"#include {_detail_include('tensor_view.h')}",
Expand Down Expand Up @@ -210,7 +212,11 @@ def _parse_runtime_functions(runtime_header):
_Function(
return_type,
name,
tuple(_parse_param(param) for param in params.split(", ") if param),
tuple(
_parse_param(param)
for param in re.split(r",\s*", params.strip())
if param
),
)
for return_type, name, params in re.findall(
r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE
Expand Down Expand Up @@ -239,6 +245,8 @@ def _selector(function):
return f"{param.name}.type()"
if param.type == "Device::Type":
return param.name
if param.type in {"Stream", "Graph", "GraphExec"}:
return f"{param.name}.device_type()"

return "current_device.type()"

Expand All @@ -250,6 +258,13 @@ def _runtime_arg(param):
return None
if param.type == "MemcpyKind":
return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})"
if param.type == "StreamCaptureMode":
return f"RuntimeStreamCaptureMode<__DEVICE_TYPE__>({param.name})"
if param.type in {"Stream", "Graph", "GraphExec"}:
return (
f"static_cast<typename Runtime<__DEVICE_TYPE__>::{param.type}>"
f"({param.name}.raw())"
)

return param.name

Expand All @@ -264,6 +279,9 @@ def _preconditions(function):
required_pointer_names = {
"GetDevice": {"device"},
"GetDeviceCount": {"count"},
"StreamCreate": {"stream"},
"StreamEndCapture": {"graph"},
"GraphInstantiate": {"graph_exec"},
}
checks = []
for param in function.params:
Expand All @@ -290,6 +308,27 @@ def _runtime_call(function):
return f"Runtime<__DEVICE_TYPE__>::{function.name}()"


def _write_stream_create(function, devices):
stream_param = function.params[0].name
cases = _dispatch_cases(
devices,
f""" typename Runtime<__DEVICE_TYPE__>::Stream raw_stream = {{}};
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamCreate(&raw_stream); }});
*{stream_param} = Stream{{__DEVICE_TYPE__, static_cast<void*>(raw_stream)}};""",
)

return f"""void StreamCreate(Stream* {stream_param}) {{
assert({stream_param} != nullptr);

switch (current_device.type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_get_device(function, devices):
device_param = function.params[0].name
cases = _dispatch_cases(
Expand All @@ -312,9 +351,81 @@ def _write_get_device(function, devices):
"""


def _write_stream_end_capture(function, devices):
stream_param = function.params[0].name
graph_param = function.params[1].name
cases = _dispatch_cases(
devices,
f""" typename Runtime<__DEVICE_TYPE__>::Graph raw_graph = {{}};
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamEndCapture(static_cast<typename Runtime<__DEVICE_TYPE__>::Stream>({stream_param}.raw()), &raw_graph); }});
*{graph_param} = Graph{{__DEVICE_TYPE__, static_cast<void*>(raw_graph)}};""",
)

return f"""void StreamEndCapture(Stream {stream_param}, Graph* {graph_param}) {{
assert({graph_param} != nullptr);

switch ({stream_param}.device_type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_graph_instantiate(function, devices):
graph_exec_param = function.params[0].name
graph_param = function.params[1].name
cases = _dispatch_cases(
devices,
f""" typename Runtime<__DEVICE_TYPE__>::GraphExec raw_graph_exec = {{}};
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphInstantiate(&raw_graph_exec, static_cast<typename Runtime<__DEVICE_TYPE__>::Graph>({graph_param}.raw())); }});
*{graph_exec_param} = GraphExec{{__DEVICE_TYPE__, static_cast<void*>(raw_graph_exec)}};""",
)

return f"""void GraphInstantiate(GraphExec* {graph_exec_param}, Graph {graph_param}) {{
assert({graph_exec_param} != nullptr);

switch ({graph_param}.device_type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_graph_launch(function, devices):
graph_exec_param = function.params[0].name
stream_param = function.params[1].name
cases = _dispatch_cases(
devices,
f""" CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphLaunch(static_cast<typename Runtime<__DEVICE_TYPE__>::GraphExec>({graph_exec_param}.raw()), static_cast<typename Runtime<__DEVICE_TYPE__>::Stream>({stream_param}.raw())); }});""",
)

return f"""void GraphLaunch(GraphExec {graph_exec_param}, Stream {stream_param}) {{
assert({graph_exec_param}.device_type() == {stream_param}.device_type());

switch ({graph_exec_param}.device_type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_dispatch_function(function, devices):
if function.name == "GetDevice":
return _write_get_device(function, devices)
if function.name == "StreamCreate":
return _write_stream_create(function, devices)
if function.name == "StreamEndCapture":
return _write_stream_end_capture(function, devices)
if function.name == "GraphInstantiate":
return _write_graph_instantiate(function, devices)
if function.name == "GraphLaunch":
return _write_graph_launch(function, devices)

cases = _dispatch_cases(
devices,
Expand Down Expand Up @@ -390,6 +501,22 @@ def _write_runtime_dispatch(source_path, runtime_header, devices):
return Runtime<kDev>::MemcpyHostToHost;
}}

template <Device::Type kDev>
auto RuntimeStreamCaptureMode(StreamCaptureMode mode) {{
switch (mode) {{
case StreamCaptureMode::kGlobal:
return Runtime<kDev>::StreamCaptureModeGlobal;
case StreamCaptureMode::kThreadLocal:
return Runtime<kDev>::StreamCaptureModeThreadLocal;
case StreamCaptureMode::kRelaxed:
return Runtime<kDev>::StreamCaptureModeRelaxed;
}}

assert(false && "unsupported stream capture mode");
std::abort();
return Runtime<kDev>::StreamCaptureModeGlobal;
}}

}} // namespace

{dispatch_functions}
Expand Down
Loading
Loading