Plugging Grain into JAX training: batching + accelerator transfer#
This guide covers the last mile between a Grain pipeline and a JAX training step: how to batch records into arrays of the right shape, and how to move those batches onto your accelerators efficiently, with host-to-device prefetching and sharding across devices for data-parallel training.
# @test {"output": "ignore"}
!pip install grain
# @test {"output": "ignore"}
!pip install tensorflow_datasets
Successfully installed absl-py-2.4.0 array-record-0.8.3 cloudpickle-3.1.2 grain-0.2.16 numpy-2.4.4 protobuf-7.34.1
Successfully installed dm-tree-0.1.10 docstring-parser-0.18.0 einops-0.8.2 googleapis-common-protos-1.75.0 immutabledict-4.3.1 promise-2.3 pyarrow-24.0.0 simple_parsing-0.1.8 tensorflow-metadata-1.17.3 tensorflow_datasets-4.9.10 termcolor-3.3.0 toml-0.10.2 tqdm-4.67.3 wrapt-2.1.2
import grain
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
1. Minimal end-to-end pipeline#
The shortest pipeline you’d want for JAX training: source -> shuffle -> preprocess -> batch -> iterate -> device_put -> step.
source = tfds.data_source("mnist", split="train")

ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(batch_size=128, drop_remainder=True)  # new leading dim
    .to_iter_dataset()
)

for batch in ds:
    batch = jax.device_put(batch)  # default device
    print(jax.tree.map(lambda x: (x.shape, x.dtype), batch))
    break
{'image': ((128, 28, 28, 1), dtype('float32')), 'label': ((128,), dtype('int32'))}
A few things to notice:
- `batch(...)` lives on `MapDataset`. It stacks PyTree leaves along a new leading axis (here `[128, 28, 28, 1]` for images, `[128]` for labels).
- `drop_remainder=True` guarantees a static batch shape, which lets `jax.jit` cache one compiled version of the step (see the sketch after this list).
- `to_iter_dataset()` turns the random-access `MapDataset` into an `IterDataset`. Do this after any random-access transforms (shuffle, batch, repeat) and before any streaming transforms (prefetch, `device_put`).
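To see why the static shape matters, here is a minimal sketch: a Python-level `print` inside a jitted function runs only while JAX traces, so it fires once per distinct input shape.

@jax.jit
def toy_step(batch):
    # Executes at trace time only, i.e. once per distinct input shape.
    print("tracing for shape", batch.shape)
    return batch.mean()

x = jnp.zeros((128, 28, 28, 1))
toy_step(x)        # traces and compiles
toy_step(x)        # cache hit, no retrace
toy_step(x[:100])  # a short final batch would force a retrace here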
2. Batching tips that matter for JAX#
**Stable shapes.** JAX recompiles whenever input shapes change. Pair `batch(drop_remainder=True)` with `.repeat()` so the loop never produces a short final batch:
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .repeat()  # infinite stream
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
)
print("length:", len(ds)) # sys.maxsize
length: 72057594037927935
**Custom collation.** The default `batch_fn` stacks leaves with `np.stack`. Pass your own when you need padding, ragged handling, or anything non-uniform:
def pad_collate(items):
    max_len = max(x["tokens"].shape[0] for x in items)
    tokens = np.stack([
        np.pad(x["tokens"], (0, max_len - x["tokens"].shape[0]))
        for x in items
    ])
    return {"tokens": tokens}

# Toy stream of variable-length token sequences.
ragged = grain.MapDataset.source(
    [{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)]
)
ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True)
print(ragged[0]["tokens"].shape)
(4, 5)
For variable-length token streams, also look at `grain.experimental.batch_and_pad`, which pads partial final batches to the requested batch size with a sentinel, so you keep one static shape without dropping data.
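If you also want a static sequence length, so the time axis never varies between batches, the same `batch_fn` hook can pad to a fixed bound. A minimal sketch; `MAX_LEN` is an assumed dataset-specific cap, and the mask is one common way to ignore padding in the loss:

MAX_LEN = 8  # assumed upper bound on sequence length for this toy data

def pad_to_fixed(items):
    # Pad every sequence to MAX_LEN so all batches share one static shape.
    tokens = np.stack([
        np.pad(x["tokens"], (0, MAX_LEN - x["tokens"].shape[0]))
        for x in items
    ])
    # True for real tokens, False for padding.
    mask = np.stack([
        np.arange(MAX_LEN) < x["tokens"].shape[0] for x in items
    ])
    return {"tokens": tokens, "mask": mask}

fixed = grain.MapDataset.source(
    [{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)]
).batch(4, batch_fn=pad_to_fixed, drop_remainder=True)
print(fixed[0]["tokens"].shape)  # always (4, 8)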
3. Moving batches to the accelerator#
There are three options. Pick the lowest tier that meets your needs.
Option A: plain jax.device_put#
Fine for prototyping and small models:
ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)

for step, batch in zip(range(2), ds):
    batch = jax.device_put(batch)
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
The transfer happens on the main thread between every `next(...)`, so the host blocks while the device receives data. In a real training loop this can leave the accelerator idle.
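To make the stall concrete, a rough timing sketch (the absolute numbers depend entirely on your hardware):

import time

it = iter(ds)
t0 = time.perf_counter()
host_batch = next(it)             # host-side pipeline work
t1 = time.perf_counter()
dev_batch = jax.device_put(host_batch)
jax.block_until_ready(dev_batch)  # wait for the transfer to finish
t2 = time.perf_counter()
print(f"pipeline: {t1 - t0:.4f}s  transfer: {t2 - t1:.4f}s")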
Option B: overlap host work with ThreadPrefetchIterDataset#
Run the pipeline’s CPU work on a background thread so the next batch is ready by the time the device is done with the previous step:
ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4)
ds = ds.map(jax.device_put)  # transfer still on iter thread

first = next(iter(ds))
print(first["image"].shape, first["image"].sharding)
(128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
Option C: two-stage prefetch with grain.experimental.device_put#
The recommended pattern for real training. It maintains both a CPU-side buffer and a device-resident buffer, so a batch is already on the accelerator before the step asks for it:
ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds,
    device=jax.devices()[0],  # or a Sharding (see below)
    cpu_buffer_size=4,        # batches buffered on host
    device_buffer_size=2,     # batches buffered on device
)

for step, batch in zip(range(2), ds):
    # `batch` is already a jax.Array on-device.
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
Under the hood this is just `ThreadPrefetch -> map(jax.device_put) -> ThreadPrefetch`.
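If you ever need to tune the stages independently, the equivalent manual composition, a sketch built from the pieces shown above with the same buffer sizes, looks like:

# Host-side prefetch: fill a CPU buffer on a background thread.
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4)
# Transfer each batch as it leaves the buffer.
ds = ds.map(lambda batch: jax.device_put(batch, jax.devices()[0]))
# Device-side prefetch: keep transferred batches ready ahead of the step.
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=2)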
4. Multi-device: sharding a batch across accelerators#
For data-parallel training across all local devices, pass a `Sharding` to `device_put` instead of a single device. Each batch is split along its first axis:
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle().repeat()
    .map(lambda r: {"image": r["image"].astype(np.float32) / 255.0,
                    "label": r["label"]})
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds,
    device=sharding,
    cpu_buffer_size=4,
    device_buffer_size=2,
)

batch = next(iter(ds))
print(batch["image"].sharding)
NamedSharding(mesh=Mesh('data': 1, axis_types=(Auto,)), spec=PartitionSpec('data',), memory_kind=device)
Make sure `batch_size` is divisible by `len(devices)`; otherwise the sharding split fails. Decorate your train step with `jax.jit` and JAX will compile a single SPMD program that handles the per-device slices automatically.
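To sanity-check the split, you can inspect the batch's addressable shards; each device should hold `batch_size // len(devices)` rows:

batch = next(iter(ds))
for shard in batch["image"].addressable_shards:
    # Each shard is the slice of the batch living on one device.
    print(shard.device, shard.data.shape)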
5. Putting it all together#
A realistic single-host, multi-device template:
BATCH = 256
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

def preprocess(r):
    return {"image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"]}

ds = (
    grain.MapDataset.source(source)
    .seed(42).shuffle().repeat()
    .map(preprocess)
    .batch(BATCH, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds, device=sharding,
    cpu_buffer_size=4, device_buffer_size=2,
)

@jax.jit
def train_step(params, batch):
    # Replace with your real loss/update.
    return params + batch["image"].mean()

params = jnp.zeros(())
for step, batch in zip(range(3), ds):
    params = train_step(params, batch)

print("final params:", params)
final params: 0.3906955