Plugging Grain into JAX training: batching + accelerator transfer#

Open in Colab

This guide covers the last mile between a Grain pipeline and a JAX training step: how to batch records into arrays of the right shape, and how to move those batches onto your accelerators efficiently, with host-to-device prefetching and sharding across devices for data-parallel training.

# @test {"output": "ignore"}
!pip install grain jax
# @test {"output": "ignore"}
!pip install tensorflow_datasets
Successfully installed absl-py-2.4.0 array-record-0.8.3 cloudpickle-3.1.2 grain-0.2.16 numpy-2.4.4 protobuf-7.34.1
Successfully installed dm-tree-0.1.10 docstring-parser-0.18.0 einops-0.8.2 googleapis-common-protos-1.75.0 immutabledict-4.3.1 promise-2.3 pyarrow-24.0.0 simple_parsing-0.1.8 tensorflow-metadata-1.17.3 tensorflow_datasets-4.9.10 termcolor-3.3.0 toml-0.10.2 tqdm-4.67.3 wrapt-2.1.2
import grain
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds

1. Minimal end-to-end pipeline#

The shortest pipeline you’d want for JAX training: source -> shuffle -> preprocess -> batch -> iterate -> device_put -> step.

source = tfds.data_source("mnist", split="train")

ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(batch_size=128, drop_remainder=True)  # new leading dim
    .to_iter_dataset()
)

for batch in ds:
    batch = jax.device_put(batch)  # default device
    print(jax.tree.map(lambda x: (x.shape, x.dtype), batch))
    break
{'image': ((128, 28, 28, 1), dtype('float32')), 'label': ((128,), dtype('int32'))}

A few things to notice:

  • batch(...) lives on MapDataset. It stacks PyTree leaves along a new leading axis (here [128, 28, 28, 1] for images, [128] for labels).

  • drop_remainder=True guarantees a static batch shape, which lets jax.jit cache one compiled version of the step.

  • to_iter_dataset() turns the random-access MapDataset into an IterDataset. Do this after any random-access transforms (shuffle, batch, repeat) and before any streaming transforms (prefetch, device_put).
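The stacking behavior in the first bullet is easy to picture in plain NumPy. Here is a minimal sketch of what default collation does for dict-of-arrays records, assuming uniform leaf shapes (Grain's actual implementation handles arbitrary PyTrees; `stack_batch` is an illustrative name, not Grain API):

```python
import numpy as np

def stack_batch(records):
    """Stack a list of dict records into one dict of batched arrays,
    adding a new leading batch axis to every leaf."""
    return {k: np.stack([r[k] for r in records]) for k in records[0]}

records = [
    {"image": np.zeros((28, 28, 1), np.float32), "label": np.int64(i)}
    for i in range(128)
]
batch = stack_batch(records)
print(batch["image"].shape, batch["label"].shape)  # (128, 28, 28, 1) (128,)
```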

2. Batching tips that matter for JAX#

Stable shapes. JAX recompiles whenever input shapes change. Pair batch(drop_remainder=True) with .repeat() so the loop never produces a short final batch:

ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .repeat()  # infinite stream
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
)
print("length:", len(ds))  # effectively infinite (2**56 - 1)
length: 72057594037927935
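Without `.repeat()`, the length is finite and `drop_remainder` decides how many batches you get per epoch. The arithmetic, assuming the 60,000-record MNIST train split:

```python
num_records = 60_000  # MNIST train split size
batch_size = 128

full_batches = num_records // batch_size      # drop_remainder=True
all_batches = -(-num_records // batch_size)   # drop_remainder=False (ceil)
print(full_batches, all_batches)  # 468 469
```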

Custom collation. The default batch_fn stacks leaves with np.stack. Pass your own when you need padding, ragged handling, or anything non-uniform:

def pad_collate(items):
    max_len = max(x["tokens"].shape[0] for x in items)
    tokens = np.stack([
        np.pad(x["tokens"], (0, max_len - x["tokens"].shape[0]))
        for x in items
    ])
    return {"tokens": tokens}

# Toy stream of variable-length token sequences.
ragged = grain.MapDataset.source(
    [{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)]
)
ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True)
print(ragged[0]["tokens"].shape)
(4, 5)

For variable-length token streams, also look at grain.experimental.batch_and_pad — it pads partial final batches to the requested batch size with a sentinel, so you keep one static shape without dropping data.
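If `batch_and_pad` is not available in your Grain version, the same idea is simple to sketch in NumPy: pad a short final batch up to the full batch size with a sentinel and return a mask so the training step can ignore the padding rows (`pad_batch` and its sentinel are illustrative, not Grain API):

```python
import numpy as np

def pad_batch(items, batch_size, sentinel=0):
    """Stack `items` and pad with sentinel rows up to `batch_size`."""
    stacked = np.stack(items)
    n = len(items)
    if n < batch_size:
        pad = np.full((batch_size - n,) + stacked.shape[1:], sentinel,
                      dtype=stacked.dtype)
        stacked = np.concatenate([stacked, pad])
    mask = np.arange(batch_size) < n  # True for real rows
    return stacked, mask

batch, mask = pad_batch([np.ones(3), np.ones(3)], batch_size=4)
print(batch.shape, mask.tolist())  # (4, 3) [True, True, False, False]
```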

3. Moving batches to the accelerator#

There are three options, in increasing order of sophistication. Pick the simplest one that meets your needs.

Option A: plain jax.device_put#

Fine for prototyping and small models:

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)

for step, batch in zip(range(2), ds):
    batch = jax.device_put(batch)
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)

The transfer happens on the main thread between every next(...), so the host blocks while the device receives data. On a real training loop this can leave the accelerator idle.
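The fix in Option B below can be sketched with nothing but the standard library: a background thread pulls items from the source iterator into a bounded queue while the main thread consumes them. This is a toy stand-in for what `ThreadPrefetchIterDataset` provides, not its actual implementation:

```python
import queue
import threading

def thread_prefetch(iterable, buffer_size=4):
    """Yield items from `iterable`, produced on a background thread."""
    q = queue.Queue(maxsize=buffer_size)
    _END = object()  # sentinel marking iterator exhaustion

    def producer():
        for item in iterable:
            q.put(item)  # blocks when the buffer is full
        q.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while (item := q.get()) is not _END:
        yield item

print(list(thread_prefetch(range(5))))  # [0, 1, 2, 3, 4]
```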

Option B: overlap host work with ThreadPrefetchIterDataset#

Run the pipeline’s CPU work on a background thread so the next batch is ready by the time the device is done with the previous step:

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4)
ds = ds.map(jax.device_put)  # transfer still on iter thread

first = next(iter(ds))
print(first["image"].shape, first["image"].sharding)
(128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)

Option C: two-stage prefetch with grain.experimental.device_put#

The recommended pattern for real training. It maintains both a CPU-side buffer and a device-resident buffer, so a batch is already on the accelerator before the step asks for it:

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)

ds = grain.experimental.device_put(
    ds=ds,
    device=jax.devices()[0],     # or a Sharding (see below)
    cpu_buffer_size=4,           # batches buffered on host
    device_buffer_size=2,        # batches buffered on device
)

for step, batch in zip(range(2), ds):
    # `batch` is already a jax.Array on-device.
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)

Under the hood this is just ThreadPrefetch -> map(jax.device_put) -> ThreadPrefetch.

4. Multi-device: sharding a batch across accelerators#

For data-parallel training across all local devices, pass a Sharding to device_put instead of a single device. Each batch is split along its first axis:

devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle().repeat()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)

ds = grain.experimental.device_put(
    ds=ds,
    device=sharding,
    cpu_buffer_size=4,
    device_buffer_size=2,
)

batch = next(iter(ds))
print(batch["image"].sharding)
NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data',), memory_kind=device)

Make sure batch_size is divisible by len(devices), otherwise the sharding split fails. Decorate your train step with jax.jit and JAX will compile a single SPMD program that handles the per-device slices automatically.
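The divisibility requirement is easy to check up front. The following sketch uses `np.split` only to illustrate the per-device layout along the leading axis; it is not how JAX actually performs the transfer:

```python
import numpy as np

batch_size, num_devices = 128, 8
assert batch_size % num_devices == 0, "batch must divide evenly across devices"

# Each device receives a contiguous slice of the leading (batch) axis.
batch = np.zeros((batch_size, 28, 28, 1), np.float32)
per_device = np.split(batch, num_devices)
print(len(per_device), per_device[0].shape)  # 8 (16, 28, 28, 1)
```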

5. Putting it all together#

A realistic single-host, multi-device template:

BATCH = 256
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

def preprocess(r):
    return {"image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"]}

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle().repeat()
    .map(preprocess)
    .batch(BATCH, drop_remainder=True)
    .to_iter_dataset()
)

ds = grain.experimental.device_put(
    ds=ds, device=sharding,
    cpu_buffer_size=4, device_buffer_size=2,
)

@jax.jit
def train_step(params, batch):
    # Replace with your real loss/update.
    return params + batch["image"].mean()

params = jnp.zeros(())
for step, batch in zip(range(3), ds):
    params = train_step(params, batch)
print("final params:", params)
final params: 0.3906955