Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
82ad598
Add torch.Tensor fast path for StridedMemoryView via AOTI tensor bridge
leofang Apr 9, 2026
f8f8d8c
Clean up tensor bridge: remove unused AOTI decls, lazy dtype, drop em…
leofang Apr 9, 2026
af06e9b
Move torch tensor fast path into each from_* classmethod
leofang Apr 9, 2026
44be580
Add stream ordering for torch tensor bridge
leofang Apr 9, 2026
6e6b8a6
Extract reusable sync_torch_stream and apply to CAI path
leofang Apr 9, 2026
85caaaf
Nits: add check_aoti helper, size_t itemsize, 2D sliced test
leofang Apr 9, 2026
9fad471
Revert itemsize to int, memoize int(stream_ptr)
leofang Apr 9, 2026
cc4558a
Use except?-1 instead of except* for check_aoti
leofang Apr 9, 2026
5f49e7a
Require PyTorch >= 2.3 for tensor bridge, move imports to module level
leofang Apr 9, 2026
b98fe71
Add tensor bridge entry to 1.0.0 release notes
leofang Apr 9, 2026
30ba7d5
Update speedup range in release notes to match benchmarks
leofang Apr 9, 2026
0f57646
Document THPVariable layout change across PyTorch versions
leofang Apr 9, 2026
74798e7
Cache type check in _is_torch_tensor for ~20% speedup
leofang Apr 9, 2026
00b8ec9
Add upper bound to torch version check (cap at 2.11)
leofang Apr 10, 2026
0c31df1
Update module docstring to document both THPVariable layouts
leofang Apr 10, 2026
8c20237
Use except?-1 for sync_torch_stream instead of except*
leofang Apr 10, 2026
8c019b9
Fix linter errors
leofang Apr 10, 2026
6682646
Fix pyobj_to_aten_handle for PyTorch 2.3–2.9 MaybeOwned layout
leofang Apr 14, 2026
0b7245b
Consolidate torch tensor bridge tests into TestViewCPU/TestViewGPU
leofang Apr 14, 2026
626736a
Extract _arr_size helper for torch/numpy size compatibility
leofang Apr 14, 2026
d1d3841
Fix ruff formatting in test_utils.py
leofang Apr 14, 2026
b9d80e7
Add readonly comment and fix vendored header license to BSD-3-Clause
leofang Apr 14, 2026
7d46123
Merge bfloat16 test into test_torch_tensor_bridge_dtypes parametrization
leofang Apr 14, 2026
c7331a9
Fix SPDX linter: use PyTorch copyright in vendored header
leofang Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .spdx-ignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ cuda_bindings/examples/*

# Vendored
cuda_core/cuda/core/_include/dlpack.h
cuda_core/cuda/core/_include/aoti_shim.h

qa/ctk-next.drawio.svg
94 changes: 94 additions & 0 deletions cuda_core/cuda/core/_include/aoti_shim.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
 * Vendored subset of PyTorch's AOT Inductor (AOTI) stable C ABI.
 * Original: torch/csrc/inductor/aoti_torch/c/shim.h
 *
 * These are declarations only -- no definitions are provided. The actual
 * symbols are exported by libtorch (loaded via torch._C with RTLD_GLOBAL)
 * and resolved at runtime by the dynamic linker. This means PyTorch is
 * NOT required at compile time.
 *
 * From PyTorch:
 *
 * Copyright (c) 2016- Facebook, Inc (Adam Paszke)
 * Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
 * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
 * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
 * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
 * Copyright (c) 2011-2013 NYU (Clement Farabet)
 * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
 * Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
 * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
 *
 * SPDX-License-Identifier: BSD-3-Clause
 * See https://github.com/pytorch/pytorch/blob/main/LICENSE
 */

#ifndef CUDA_CORE_AOTI_SHIM_H
#define CUDA_CORE_AOTI_SHIM_H

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

/* Error code returned by every aoti_torch_* accessor.
 * NOTE(review): upstream appears to use 0 for success and non-zero for
 * failure -- confirm against shim.h before relying on specific values. */
typedef int32_t AOTITorchError;

/* Opaque tensor handle -- corresponds to at::Tensor on the C++ side. */
struct AtenTensorOpaque;
typedef struct AtenTensorOpaque* AtenTensorHandle;

/* ---- tensor metadata --------------------------------------------------- */

/* Raw device/host pointer to the tensor's storage. */
AOTITorchError aoti_torch_get_data_ptr(
    AtenTensorHandle tensor, void** ret_data_ptr);

/* Number of dimensions (rank). */
AOTITorchError aoti_torch_get_dim(
    AtenTensorHandle tensor, int64_t* ret_dim);

/* NOTE(review): the sizes/strides pointers presumably alias storage owned
 * by the tensor and stay valid only while the tensor is alive -- confirm
 * against the upstream shim before caching them. */
AOTITorchError aoti_torch_get_sizes(
    AtenTensorHandle tensor, int64_t** ret_sizes);

AOTITorchError aoti_torch_get_strides(
    AtenTensorHandle tensor, int64_t** ret_strides);

/* ---- dtype ------------------------------------------------------------- */

AOTITorchError aoti_torch_get_dtype(
    AtenTensorHandle tensor, int32_t* ret_dtype);

/* Each function returns the runtime integer code for the named dtype;
 * compare against the value written by aoti_torch_get_dtype(). */
int32_t aoti_torch_dtype_float16(void);
int32_t aoti_torch_dtype_float32(void);
int32_t aoti_torch_dtype_float64(void);
int32_t aoti_torch_dtype_bfloat16(void);
int32_t aoti_torch_dtype_uint8(void);
int32_t aoti_torch_dtype_int8(void);
int32_t aoti_torch_dtype_int16(void);
int32_t aoti_torch_dtype_int32(void);
int32_t aoti_torch_dtype_int64(void);
int32_t aoti_torch_dtype_bool(void);
int32_t aoti_torch_dtype_complex32(void);
int32_t aoti_torch_dtype_complex64(void);
int32_t aoti_torch_dtype_complex128(void);

/* ---- device ------------------------------------------------------------ */

AOTITorchError aoti_torch_get_device_type(
    AtenTensorHandle tensor, int32_t* ret_device_type);

AOTITorchError aoti_torch_get_device_index(
    AtenTensorHandle tensor, int32_t* ret_device_index);

/* Runtime integer codes for device types, analogous to the dtype codes. */
int32_t aoti_torch_device_type_cpu(void);
int32_t aoti_torch_device_type_cuda(void);

/* ---- stream -------------------------------------------------------------- */

/* Current CUDA stream for the given device, as seen by PyTorch. */
AOTITorchError aoti_torch_get_current_cuda_stream(
    int32_t device_index, void** ret_stream);

#ifdef __cplusplus
} /* extern "C" */
#endif

#endif /* CUDA_CORE_AOTI_SHIM_H */
94 changes: 94 additions & 0 deletions cuda_core/cuda/core/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ from libc.stdint cimport intptr_t
from cuda.core._layout cimport _StridedLayout, get_strides_ptr
from cuda.core._stream import Stream

import ctypes
import functools
import sys
import warnings

import numpy
Expand All @@ -29,6 +31,73 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
from cuda.core._memory import Buffer


# ---------------------------------------------------------------------------
# Lazy tensor bridge (avoids loading _tensor_bridge.so until torch is used)
# ---------------------------------------------------------------------------

# Memoized module object for cuda.core._tensor_bridge; populated on first
# use by _get_tensor_bridge() so the extension is never loaded unless a
# torch.Tensor is actually passed in.
cdef object _tensor_bridge = None
# Cache: type(obj) -> True/False for the torch tensor check.
# Once a type is seen, we never re-check.
cdef dict _torch_type_cache = {}
# Tri-state: None = not checked, True/False = result of version check
cdef object _torch_version_ok = None

cdef inline bint _torch_version_check():
    """Return True if 2.3 <= torch <= 2.11 (known AOTI ABI range). Memoized.

    Lower bound: AOTI functions we use were introduced in PyTorch 2.3.
    Upper bound: the ``pyobj_to_aten_handle`` trick relies on the
    THPVariable struct layout (PyObject_HEAD followed by at::Tensor cdata)
    and the identity ``AtenTensorHandle == at::Tensor*``. Both are
    undocumented internals that could change in a future PyTorch version.
    We cap at the latest version we have tested against; unknown versions
    fall back to the standard DLPack/CAI paths. Bump the upper bound
    after verifying a new PyTorch release.
    """
    global _torch_version_ok
    if _torch_version_ok is not None:
        return <bint>_torch_version_ok
    torch = sys.modules.get("torch")
    if torch is None:
        _torch_version_ok = False
        return False
    try:
        # Split once (the original parsed the version string twice).
        # Suffixes such as "2.3.0a0+cu121" are fine: only the first two
        # components are consumed and they remain plain integers.
        parts = torch.__version__.split(".")
        major = int(parts[0])
        minor = int(parts[1])
        _torch_version_ok = (2, 3) <= (major, minor) <= (2, 11)
    except (ValueError, IndexError):
        # Unparseable or single-component version -> treat as unsupported.
        _torch_version_ok = False
    return <bint>_torch_version_ok


cdef inline bint _is_torch_tensor(object obj):
    """Return True if ``obj`` should take the AOTI torch-tensor fast path.

    The verdict is cached per concrete type: once a type has been
    classified it is never re-checked (see ``_torch_type_cache``).
    """
    cdef type tp = type(obj)
    cdef object cached = _torch_type_cache.get(tp)
    if cached is not None:
        return <bint>cached
    cdef str mod = tp.__module__ or ""
    # Match the "torch" package exactly (top-level module or a "torch."
    # submodule, e.g. torch.nn.parameter for Parameter subclasses). A bare
    # startswith("torch") would also match unrelated packages such as
    # "torchvision" or "torch_geometric", whose types could expose a
    # data_ptr attribute yet not be AOTI-compatible tensors.
    cdef bint result = (mod == "torch" or mod.startswith("torch.")) \
        and hasattr(obj, "data_ptr") and _torch_version_check()
    _torch_type_cache[tp] = result
    return result


cdef object _get_tensor_bridge():
    """Bootstrap AOTI symbols, then import _tensor_bridge on first use.

    The resolved module is memoized in the module-level ``_tensor_bridge``
    so the dlopen/import work happens at most once per process.
    """
    global _tensor_bridge
    if _tensor_bridge is None:
        torch_core = sys.modules.get("torch._C")
        if torch_core is None:
            raise RuntimeError(
                "torch._C is not loaded; cannot initialise the tensor bridge. "
                "Make sure PyTorch is imported before passing a torch.Tensor.")
        # Re-open torch._C with RTLD_GLOBAL so the aoti_* symbols declared
        # in aoti_shim.h resolve when _tensor_bridge is imported below.
        ctypes.CDLL(torch_core.__file__, mode=ctypes.RTLD_GLOBAL)
        from cuda.core import _tensor_bridge as bridge_module
        _tensor_bridge = bridge_module
    return _tensor_bridge


try:
from ml_dtypes import bfloat16
except ImportError:
Expand Down Expand Up @@ -150,6 +219,9 @@ cdef class StridedMemoryView:
Stream pointer for synchronization. If ``None``, no synchronization is performed.
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
return buf
view_as_dlpack(obj, stream_ptr, buf)
return buf

Expand All @@ -165,6 +237,9 @@ cdef class StridedMemoryView:
Stream pointer for synchronization. If ``None``, no synchronization is performed.
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, stream_ptr, buf)
return buf
view_as_cai(obj, stream_ptr, buf)
return buf

Expand All @@ -178,6 +253,9 @@ cdef class StridedMemoryView:
An object implementing the `__array_interface__ <https://numpy.org/doc/stable/reference/arrays.interface.html>`_ protocol (e.g., a numpy array).
"""
cdef StridedMemoryView buf = StridedMemoryView.__new__(cls)
if _is_torch_tensor(obj):
_get_tensor_bridge().view_as_torch_tensor(obj, None, buf)
return buf
view_as_array_interface(obj, buf)
return buf

Expand All @@ -187,6 +265,8 @@ cdef class StridedMemoryView:

Tries `DLPack <https://dmlc.github.io/dlpack/latest/>`_ first, then falls back to
`__cuda_array_interface__ <https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html>`_.
``torch.Tensor`` objects are transparently handled via a fast AOTI path
regardless of which protocol is selected.

Parameters
----------
Expand Down Expand Up @@ -480,6 +560,10 @@ cdef class StridedMemoryView:
if self._dtype is None:
if self.dl_tensor != NULL:
self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
elif isinstance(self.metadata, int):
# AOTI dtype code stored by the torch tensor bridge
self._dtype = _get_tensor_bridge().resolve_aoti_dtype(
self.metadata)
elif self.metadata is not None:
self._dtype = _typestr2dtype(self.metadata["typestr"])
return self._dtype
Expand Down Expand Up @@ -1122,6 +1206,16 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
as_cu(h_event), <cydriver.CUstream>producer_s))
HANDLE_RETURN(cydriver.cuStreamWaitEvent(
<cydriver.CUstream>consumer_s, as_cu(h_event), 0))
elif _is_torch_tensor(obj):
# PyTorch's __cuda_array_interface__ reports version 2 and
# omits the "stream" field, so the standard CAI sync path
# above is a no-op for torch tensors. This is unsafe: the
# consumer has no guarantee that the producer's work is
# visible. We fix this by querying PyTorch's current CUDA
# stream via the AOTI stable C ABI and performing the same
# event-based stream ordering.
_get_tensor_bridge().sync_torch_stream(
buf.device_id, <intptr_t>(stream_ptr))

return buf

Expand Down
Loading
Loading