44#include < stdio.h>
55#include < string>
66#include < tuple>
7+ #include < unistd.h>
78#include < unordered_map>
89
910#include " absl/log/globals.h"
@@ -306,10 +307,10 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND);
306307// ExlaBuffer Functions
307308
308309std::variant<std::tuple<fine::Atom, uint64_t , uint64_t >,
309- std::tuple<fine::Atom, std::string, uint64_t , uint64_t >,
310310 std::tuple<fine::Atom, std::string, uint64_t >>
311311get_buffer_device_pointer (ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
312- fine::Term buffer_term, fine::Atom pointer_kind) {
312+ fine::Term buffer_term, fine::Atom pointer_kind,
313+ int64_t shm_permissions) {
313314 auto buffer = decode_exla_buffer (env, buffer_term);
314315
315316 uint64_t device_size = unwrap (buffer->GetOnDeviceSizeInBytes ());
@@ -321,22 +322,38 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
321322
322323 if (pointer_kind == " host_ipc" ) {
323324 auto handle_name =
324- " exla:ipc:" + std::to_string (device_size) + " :" + std::to_string (ptr);
325- auto fd = get_ipc_handle (handle_name.c_str (), device_size);
325+ " exla:ipc:" + std::to_string (getpid ()) + " :" + std::to_string (ptr);
326+ auto fd = get_ipc_handle (handle_name.c_str (), device_size,
327+ static_cast <mode_t >(shm_permissions));
326328
327329 if (fd == -1 ) {
328330 throw std::runtime_error (" unable to get IPC handle" );
329331 }
330332
331- auto ipc_ptr = open_ipc_handle (fd, device_size);
333+ auto ipc_ptr = open_ipc_handle (fd, device_size, /* writable= */ 1 );
332334 if (ipc_ptr == nullptr ) {
333335 throw std::runtime_error (" unable to open IPC handle" );
334336 }
335337
336338 memcpy (ipc_ptr, reinterpret_cast <void *>(ptr), device_size);
337339
338- return std::make_tuple (pointer_kind, handle_name, static_cast <uint64_t >(fd),
339- device_size);
340+ // Repoint the original buffer at the shm mapping so both the exporter
341+ // and any importers share the same physical pages. This also frees
342+ // the old XLA-managed memory, avoiding double memory usage.
343+ auto shape = unwrap (buffer->buffer ()->logical_on_device_shape ());
344+ auto device = unwrap (client->client ()->LookupDevice (
345+ xla::PjRtGlobalDeviceId (buffer->device_id ())));
346+ auto memory_space = unwrap (device->default_memory_space ());
347+
348+ auto on_delete = [fd, ipc_ptr, device_size, handle_name]() {
349+ close_ipc_handle (fd, ipc_ptr, handle_name.c_str (), device_size);
350+ };
351+
352+ auto new_pjrt_buf = unwrap (client->client ()->CreateViewOfDeviceBuffer (
353+ ipc_ptr, shape, memory_space, on_delete));
354+ buffer->ReplaceBuffer (std::move (new_pjrt_buf));
355+
356+ return std::make_tuple (pointer_kind, handle_name, device_size);
340357 }
341358
342359 if (pointer_kind == " cuda_ipc" ) {
@@ -369,17 +386,20 @@ fine::ResourcePtr<ExlaBuffer> create_buffer_from_device_pointer(
369386 }
370387 ptr = maybe_pointer.value ();
371388 } else if (pointer_kind == " host_ipc" ) {
372- auto tuple =
373- fine::decode<std::tuple<uint64_t , std::string>>(env, pointer_data);
374- auto fd = std::get<0 >(tuple);
375- auto memname = std::get<1 >(tuple);
389+ auto memname = fine::decode<std::string>(env, pointer_data);
376390 auto device_size = xla::ShapeUtil::ByteSizeOf (shape);
377- ptr = open_ipc_handle (fd, device_size);
391+ int writable = 0 ;
392+ auto fd = open_existing_ipc_handle (memname.c_str (), &writable);
393+ if (fd == -1 ) {
394+ throw std::runtime_error (" unable to get fd for IPC handle" );
395+ }
396+ ptr = open_ipc_handle (fd, device_size, writable);
378397 if (ptr == nullptr ) {
398+ close (fd);
379399 throw std::runtime_error (" unable to get pointer for IPC handle" );
380400 }
381- on_delete_callback = [fd, memname, ptr, device_size]() {
382- close_ipc_handle (fd, ptr, memname. c_str () , device_size);
401+ on_delete_callback = [fd, ptr, device_size]() {
402+ close_imported_ipc_handle (fd, ptr, device_size);
383403 };
384404 } else if (pointer_kind == " local" ) {
385405 auto ptr_int = fine::decode<int64_t >(env, pointer_data);
@@ -677,6 +697,22 @@ fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) {
677697
678698FINE_NIF (start_log_sink, 0 );
679699
700+ // Test-only NIFs — excluded from production builds.
701+ #ifndef EXLA_PROD
702+
703+ // Writes `data` into the memory at `address + offset`. Intentionally unsafe:
704+ // no bounds checking. The caller is responsible for ensuring the pointer is
705+ // valid and the region is large enough.
706+ fine::Ok<> write_to_pointer (ErlNifEnv *env, uint64_t address,
707+ ErlNifBinary data, uint64_t offset) {
708+ uint8_t *ptr = reinterpret_cast <uint8_t *>(address);
709+ std::memcpy (ptr + offset, data.data , data.size );
710+ return fine::Ok ();
711+ }
712+ FINE_NIF (write_to_pointer, 0 );
713+
714+ #endif // EXLA_PROD
715+
680716} // namespace exla
681717
682718FINE_INIT (" Elixir.EXLA.NIF" );
0 commit comments