Skip to content

Commit 221ecbc

Browse files
blasphemetheuspolvalenteclaude
authored
feat(exla): configurable shm permissions for IPC buffers, default 0400 (#1732)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d528827 commit 221ecbc

12 files changed

Lines changed: 361 additions & 32 deletions

File tree

exla/Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ else
4141
CFLAGS += -O3
4242
endif
4343

44+
ifeq ($(MIX_ENV),prod)
45+
CFLAGS += -DEXLA_PROD
46+
endif
47+
4448
NVCC = $(CXX)
4549
NVCCFLAGS = $(CFLAGS)
4650
LDFLAGS += -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden

exla/c_src/exla/exla.cc

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <stdio.h>
55
#include <string>
66
#include <tuple>
7+
#include <unistd.h>
78
#include <unordered_map>
89

910
#include "absl/log/globals.h"
@@ -306,10 +307,10 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND);
306307
// ExlaBuffer Functions
307308

308309
std::variant<std::tuple<fine::Atom, uint64_t, uint64_t>,
309-
std::tuple<fine::Atom, std::string, uint64_t, uint64_t>,
310310
std::tuple<fine::Atom, std::string, uint64_t>>
311311
get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
312-
fine::Term buffer_term, fine::Atom pointer_kind) {
312+
fine::Term buffer_term, fine::Atom pointer_kind,
313+
int64_t shm_permissions) {
313314
auto buffer = decode_exla_buffer(env, buffer_term);
314315

315316
uint64_t device_size = unwrap(buffer->GetOnDeviceSizeInBytes());
@@ -321,22 +322,38 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
321322

322323
if (pointer_kind == "host_ipc") {
323324
auto handle_name =
324-
"exla:ipc:" + std::to_string(device_size) + ":" + std::to_string(ptr);
325-
auto fd = get_ipc_handle(handle_name.c_str(), device_size);
325+
"exla:ipc:" + std::to_string(getpid()) + ":" + std::to_string(ptr);
326+
auto fd = get_ipc_handle(handle_name.c_str(), device_size,
327+
static_cast<mode_t>(shm_permissions));
326328

327329
if (fd == -1) {
328330
throw std::runtime_error("unable to get IPC handle");
329331
}
330332

331-
auto ipc_ptr = open_ipc_handle(fd, device_size);
333+
auto ipc_ptr = open_ipc_handle(fd, device_size, /*writable=*/1);
332334
if (ipc_ptr == nullptr) {
333335
throw std::runtime_error("unable to open IPC handle");
334336
}
335337

336338
memcpy(ipc_ptr, reinterpret_cast<void *>(ptr), device_size);
337339

338-
return std::make_tuple(pointer_kind, handle_name, static_cast<uint64_t>(fd),
339-
device_size);
340+
// Repoint the original buffer at the shm mapping so both the exporter
341+
// and any importers share the same physical pages. This also frees
342+
// the old XLA-managed memory, avoiding double memory usage.
343+
auto shape = unwrap(buffer->buffer()->logical_on_device_shape());
344+
auto device = unwrap(client->client()->LookupDevice(
345+
xla::PjRtGlobalDeviceId(buffer->device_id())));
346+
auto memory_space = unwrap(device->default_memory_space());
347+
348+
auto on_delete = [fd, ipc_ptr, device_size, handle_name]() {
349+
close_ipc_handle(fd, ipc_ptr, handle_name.c_str(), device_size);
350+
};
351+
352+
auto new_pjrt_buf = unwrap(client->client()->CreateViewOfDeviceBuffer(
353+
ipc_ptr, shape, memory_space, on_delete));
354+
buffer->ReplaceBuffer(std::move(new_pjrt_buf));
355+
356+
return std::make_tuple(pointer_kind, handle_name, device_size);
340357
}
341358

342359
if (pointer_kind == "cuda_ipc") {
@@ -369,17 +386,20 @@ fine::ResourcePtr<ExlaBuffer> create_buffer_from_device_pointer(
369386
}
370387
ptr = maybe_pointer.value();
371388
} else if (pointer_kind == "host_ipc") {
372-
auto tuple =
373-
fine::decode<std::tuple<uint64_t, std::string>>(env, pointer_data);
374-
auto fd = std::get<0>(tuple);
375-
auto memname = std::get<1>(tuple);
389+
auto memname = fine::decode<std::string>(env, pointer_data);
376390
auto device_size = xla::ShapeUtil::ByteSizeOf(shape);
377-
ptr = open_ipc_handle(fd, device_size);
391+
int writable = 0;
392+
auto fd = open_existing_ipc_handle(memname.c_str(), &writable);
393+
if (fd == -1) {
394+
throw std::runtime_error("unable to get fd for IPC handle");
395+
}
396+
ptr = open_ipc_handle(fd, device_size, writable);
378397
if (ptr == nullptr) {
398+
close(fd);
379399
throw std::runtime_error("unable to get pointer for IPC handle");
380400
}
381-
on_delete_callback = [fd, memname, ptr, device_size]() {
382-
close_ipc_handle(fd, ptr, memname.c_str(), device_size);
401+
on_delete_callback = [fd, ptr, device_size]() {
402+
close_imported_ipc_handle(fd, ptr, device_size);
383403
};
384404
} else if (pointer_kind == "local") {
385405
auto ptr_int = fine::decode<int64_t>(env, pointer_data);
@@ -677,6 +697,22 @@ fine::Ok<> start_log_sink(ErlNifEnv *env, ErlNifPid logger_pid) {
677697

678698
FINE_NIF(start_log_sink, 0);
679699

700+
// Test-only NIFs — excluded from production builds.
701+
#ifndef EXLA_PROD
702+
703+
// Writes `data` into the memory at `address + offset`. Intentionally unsafe:
704+
// no bounds checking. The caller is responsible for ensuring the pointer is
705+
// valid and the region is large enough.
706+
fine::Ok<> write_to_pointer(ErlNifEnv *env, uint64_t address,
707+
ErlNifBinary data, uint64_t offset) {
708+
uint8_t *ptr = reinterpret_cast<uint8_t *>(address);
709+
std::memcpy(ptr + offset, data.data, data.size);
710+
return fine::Ok();
711+
}
712+
FINE_NIF(write_to_pointer, 0);
713+
714+
#endif // EXLA_PROD
715+
680716
} // namespace exla
681717

682718
FINE_INIT("Elixir.EXLA.NIF");

exla/c_src/exla/exla_client.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ tsl::StatusOr<ERL_NIF_TERM> ExlaBuffer::ToBinary(ErlNifEnv *env,
6464
return binary_term;
6565
}
6666

67+
void ExlaBuffer::ReplaceBuffer(std::unique_ptr<xla::PjRtBuffer> new_buffer) {
68+
if (buffer_ && !buffer_->IsDeleted()) {
69+
TrackDeallocation();
70+
buffer_->Delete();
71+
}
72+
buffer_ = std::move(new_buffer);
73+
if (client_ && buffer_) {
74+
auto size_or = GetOnDeviceSizeInBytes();
75+
if (size_or.ok()) {
76+
client_->TrackBufferAllocated(device_id(), size_or.value());
77+
}
78+
}
79+
}
80+
6781
tsl::Status ExlaBuffer::Deallocate() {
6882
if (buffer_->IsDeleted()) {
6983
return xla::FailedPrecondition(

exla/c_src/exla/exla_client.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class ExlaBuffer {
4545

4646
void SetClient(ExlaClient *client) { client_ = client; }
4747

48+
// Replace the underlying PjRt buffer with a new one (e.g. an shm-backed
49+
// view). The old buffer is deallocated first so XLA can free its memory.
50+
void ReplaceBuffer(std::unique_ptr<xla::PjRtBuffer> new_buffer);
51+
4852
~ExlaBuffer();
4953

5054
private:

exla/c_src/exla/ipc.cc

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
#include "ipc.h"
22

33
#include <cstdio>
4+
#include <errno.h>
45
#include <fcntl.h>
56
#include <sys/mman.h>
67
#include <sys/stat.h>
78
#include <unistd.h>
89

9-
// Function to create or open a shared memory object and set its size
10-
int get_ipc_handle(const char* memname, size_t memsize) {
11-
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);
10+
// Create or open a shared memory object and set its size. `mode` is the
11+
// file mode bits forwarded to shm_open(3). The default is chosen in the
12+
// Elixir caller (EXLA.Backend.to_pointer/2).
13+
int get_ipc_handle(const char* memname, size_t memsize, mode_t mode) {
14+
int fd = shm_open(memname, O_CREAT | O_RDWR, mode);
1215
if (fd == -1) {
1316
return -1;
1417
}
@@ -21,16 +24,36 @@ int get_ipc_handle(const char* memname, size_t memsize) {
2124
return fd;
2225
}
2326

24-
// Function to map the shared memory in this process
25-
void* open_ipc_handle(int fd, size_t memsize) {
26-
void* ptr = mmap(NULL, memsize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
27+
// Try O_RDWR first (permissions allow write); fall back to O_RDONLY on EACCES.
28+
int open_existing_ipc_handle(const char* memname, int* out_writable) {
29+
int fd = shm_open(memname, O_RDWR, 0);
30+
if (fd != -1) {
31+
*out_writable = 1;
32+
return fd;
33+
}
34+
if (errno == EACCES) {
35+
fd = shm_open(memname, O_RDONLY, 0);
36+
if (fd != -1) {
37+
*out_writable = 0;
38+
return fd;
39+
}
40+
}
41+
return -1;
42+
}
43+
44+
// MAP_SHARED zero-copy mapping. `writable` adds PROT_WRITE; fd access mode
45+
// must match (O_RDWR for writable, O_RDONLY for read-only).
46+
void* open_ipc_handle(int fd, size_t memsize, int writable) {
47+
int prot = writable ? (PROT_READ | PROT_WRITE) : PROT_READ;
48+
void* ptr = mmap(NULL, memsize, prot, MAP_SHARED, fd, 0);
2749
if (ptr == MAP_FAILED) {
2850
perror("mmap");
2951
return nullptr;
3052
}
3153
return ptr;
3254
}
3355

56+
// Sender cleanup: remove the shm name so the object is freed once all fds close.
3457
int close_ipc_handle(int fd, void* ptr, const char* memname, size_t memsize) {
3558
if (munmap(ptr, memsize) == -1) {
3659
return -1;
@@ -44,3 +67,10 @@ int close_ipc_handle(int fd, void* ptr, const char* memname, size_t memsize) {
4467

4568
return 0;
4669
}
70+
71+
// Receiver cleanup: only munmap + close; the sender owns the shm name/lifetime.
72+
int close_imported_ipc_handle(int fd, void* ptr, size_t memsize) {
73+
munmap(ptr, memsize);
74+
close(fd);
75+
return 0;
76+
}

exla/c_src/exla/ipc.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
11
#pragma once
22

33
#include <cstddef>
4+
#include <sys/types.h>
45

5-
int get_ipc_handle(const char* memname, size_t memsize);
6-
void* open_ipc_handle(int fd, size_t memsize);
6+
// Create a new shm segment, set its size, and return a writable fd.
7+
// `mode` is the file permission bits (e.g. 0o400, 0o600).
8+
int get_ipc_handle(const char* memname, size_t memsize, mode_t mode);
9+
10+
// Open an existing shm segment. Tries O_RDWR first; if EACCES, falls back to
11+
// O_RDONLY. Sets *out_writable to 1 if write access was granted, 0 otherwise.
12+
// Returns -1 on error.
13+
int open_existing_ipc_handle(const char* memname, int* out_writable);
14+
15+
// Map a shm fd with MAP_SHARED. `writable` controls whether PROT_WRITE is
16+
// included; the fd must have been opened with the matching access mode.
17+
void* open_ipc_handle(int fd, size_t memsize, int writable);
18+
19+
// Sender cleanup: munmap + close + shm_unlink.
720
int close_ipc_handle(int fd, void* ptr, const char* memname, size_t memsize);
21+
22+
// Receiver cleanup: munmap + close only (no unlink — sender owns the name).
23+
int close_imported_ipc_handle(int fd, void* ptr, size_t memsize);

exla/lib/exla/backend.ex

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,14 @@ defmodule EXLA.Backend do
8484

8585
@impl true
8686
def to_pointer(%T{data: %B{buffer: buffer}}, opts \\ []) do
87-
opts = Keyword.validate!(opts, mode: :local)
87+
opts = Keyword.validate!(opts, mode: :local, permissions: 0o400)
88+
permissions = opts[:permissions]
89+
90+
unless is_integer(permissions) and permissions >= 0 and permissions <= 0o7777 do
91+
raise ArgumentError,
92+
":permissions must be an integer in the range 0..0o7777 " <>
93+
"(typically an octal literal like 0o400), got: #{inspect(permissions)}"
94+
end
8895

8996
mode =
9097
case {opts[:mode], buffer} do
@@ -108,18 +115,18 @@ defmodule EXLA.Backend do
108115

109116
{mode, _} ->
110117
raise ArgumentError,
111-
"expected one of :local, :cuda_ipc, :host_ipc, got: #{inspect(mode)}"
118+
"expected one of :local, :ipc, got: #{inspect(mode)}"
112119
end
113120

114121
client = EXLA.Client.fetch!(buffer.client_name)
115122

116-
case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
123+
case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode, permissions) do
117124
{:local, ptr, size} ->
118125
# Pointer is an integer here
119126
%Nx.Pointer{kind: :local, address: ptr, data_size: size}
120127

121-
{:host_ipc, handle_name, fd, size} ->
122-
%Nx.Pointer{kind: :ipc, handle: handle_name, address: fd, data_size: size}
128+
{:host_ipc, handle_name, size} ->
129+
%Nx.Pointer{kind: :ipc, handle: handle_name, data_size: size}
123130

124131
{:cuda_ipc, handle, size} ->
125132
%Nx.Pointer{kind: :ipc, handle: handle, address: buffer.device_id, data_size: size}
@@ -153,8 +160,8 @@ defmodule EXLA.Backend do
153160
{%Nx.Pointer{kind: :local, address: address}, _} ->
154161
{:local, address}
155162

156-
{%Nx.Pointer{kind: :ipc, address: fd, handle: handle}, :host} ->
157-
{:host_ipc, {fd, handle}}
163+
{%Nx.Pointer{kind: :ipc, handle: handle}, :host} ->
164+
{:host_ipc, handle}
158165

159166
{%Nx.Pointer{kind: :ipc, handle: handle}, :cuda} ->
160167
{:cuda_ipc, handle}

exla/lib/exla/nif.ex

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ defmodule EXLA.NIF do
5555
def mlir_get_typespec(_tensor), do: err!()
5656
def mlir_module_to_string(_builder), do: err!()
5757

58-
def get_buffer_device_pointer(_client, _buffer, _pointer_kind), do: err!()
58+
def get_buffer_device_pointer(_client, _buffer, _pointer_kind, _shm_permissions),
59+
do: err!()
5960

6061
def create_buffer_from_device_pointer(
6162
_client,
@@ -96,5 +97,11 @@ defmodule EXLA.NIF do
9697
def encode_local_pid(_pid), do: err!()
9798
def decode_local_pid(_pid_bin), do: err!()
9899

100+
if Mix.env() != :prod do
101+
# Writes `data` into the memory at `address + offset`. Test-only; not
102+
# compiled in production builds.
103+
def write_to_pointer(_address, _data, _offset), do: err!()
104+
end
105+
99106
defp err!(), do: :erlang.nif_error(:undef)
100107
end

0 commit comments

Comments
 (0)