2025, Sep 24 23:00
Avoid TensorFlow INVALID_ARGUMENT with tf.data get_single_element when using Python generators
Why INVALID_ARGUMENT: Dataset had more than one element occurs in tf.data, and how to sample generator-backed datasets safely with take(1) or next(iter(...)) while keeping RAM/VRAM in check.
When you switch a TensorFlow training pipeline from in-memory arrays to a Python generator to control RAM/VRAM growth, a common gotcha is how you probe or sample from a tf.data.Dataset. If you call get_single_element() on a dataset that has more than one element, TensorFlow will raise INVALID_ARGUMENT: Dataset had more than one element. This guide shows why that happens and how to fix it without changing the generator’s logic.
Problem setup
The pipeline builds a dataset from a generator that yields one example per row. Here is a minimal version that mirrors the original structure while keeping the naming distinct; it assumes binary 0/1 targets.
import numpy as np
import tensorflow as tf


def row_streamer(
    features_arr: np.typing.NDArray,
    targets_arr: np.typing.NDArray,
    win_len: int,
    hop: int,
):
    n_rows = features_arr.shape[0]
    # One-hot encode the binary 0/1 targets: complement in column 0,
    # the label itself in column 1.
    onehot_targets = np.stack(
        [1 - targets_arr, targets_arr],
        axis=1
    )
    # Precomputes every window for every row up front; see the memory
    # notes below for a streaming alternative.
    win_list = make_windows_batch(
        feats_mat=features_arr,
        win=win_len,
        step=hop
    )
    for idx in range(n_rows):
        yield (
            {f"slot_{jj}": arr[idx, :] for jj, arr in enumerate(win_list)},
            (
                {"embed_out": targets_arr[idx]},
                {"class_head": onehot_targets[idx, :]}
            )
        )
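The helper make_windows_batch is not shown in the original pipeline. Judging from how it is called, it returns one array per window position, each of shape (n_rows, win). A minimal NumPy stand-in consistent with that usage might look like the following; it is an assumption, not the original implementation.

def make_windows_batch(feats_mat, win, step):
    # Hypothetical stand-in: slice every row into overlapping windows of
    # length `win`, advancing by `step`. Returns one array per window
    # position, each of shape (n_rows, win).
    feat_len = feats_mat.shape[1]
    return [
        feats_mat[:, start:start + win]
        for start in range(0, feat_len - win + 1, step)
    ]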
The dataset is created with a matching signature. Each element contains a dict of framed inputs and a pair of label dicts.
train_ds = tf.data.Dataset.from_generator(
    row_streamer,
    args=[X_train, Y_train, w_len, hop],
    output_signature=(
        # One spec per framed input slot, mirroring the generator's dict keys.
        {f"slot_{jj}": tf.TensorSpec(shape=(w_len,), dtype=tf.float64, name=f"slot_{jj}")
         for jj in range(number_windows)},
        # A pair of label dicts matching the generator's yield structure.
        (
            {"embed_out": tf.TensorSpec(shape=(), dtype=tf.int32, name="embed_out")},
            {"class_head": tf.TensorSpec(shape=(2,), dtype=tf.int32, name="class_head")}
        )
    )
)
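Before pulling any data, element_spec is a cheap sanity check: it reports the declared nested structure without running the generator.

print(train_ds.element_spec)
# Prints the (inputs_dict, (labels_dict, labels_dict)) structure declared above.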
Why the error appears
get_single_element() is strict by design: it expects the dataset to contain exactly one element. The generator above yields one element per row, so the dataset contains N elements where N is the number of rows. Calling get_single_element() on such a dataset triggers the failure.
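For example, either of these calls reproduces the failure on the dataset built above:

# Both fail: train_ds contains N elements, not exactly one
example = tf.data.experimental.get_single_element(train_ds)
example = train_ds.get_single_element()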
INVALID_ARGUMENT: Dataset had more than one element.
This is not a problem with the generator logic or the signature. It’s about how get_single_element() is used.
The fix
If you want a sample for inspection or debugging, either iterate to the first element, or explicitly narrow the dataset to one element before calling get_single_element(). The following options produce the same result and differ only in API style and TF-version support.
# Eager style: pull the first element through a Python iterator
example = next(iter(train_ds))

# Narrow to one element, then extract (the experimental endpoint works on
# older TF releases)
example = tf.data.experimental.get_single_element(train_ds.take(1))

# Method form, available on tf.data.Dataset since TF 2.6
example = train_ds.take(1).get_single_element()
Each snippet ensures that get_single_element() sees exactly one element, which resolves the INVALID_ARGUMENT error.
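Once you have a sample, tf.nest.map_structure is a convenient way to inspect its nested dtypes and shapes without printing whole tensors:

# Maps over every tensor in the nested (inputs, (labels, labels)) structure
print(tf.nest.map_structure(lambda t: (t.dtype.name, t.shape), example))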
Additional memory notes
Two adjustments can further reduce memory pressure during input processing. First, consider whether you actually need float64 inputs: float64 occupies twice the memory of float32, and Keras layers default to float32 anyway. Second, avoid precomputing all windows for all rows outside the per-row loop, since that defeats the streaming purpose of a generator. Either compute windows per row inside the loop, or build the dataset from rows and create windows in a map stage with tf.signal.frame so that only the current batch’s windows occupy memory; a sketch of the map-stage variant follows.
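This is a minimal sketch of that variant. It reuses the X_train, Y_train, w_len, hop, and number_windows names from the snippets above, applies the float32 suggestion, and assumes binary 0/1 targets; the batch size is illustrative.

def frame_row(features_row, target, onehot):
    # tf.signal.frame slices one row into overlapping windows on the fly,
    # so only the current element's windows are ever materialized.
    frames = tf.signal.frame(features_row, frame_length=w_len, frame_step=hop)
    inputs = {f"slot_{jj}": frames[jj] for jj in range(number_windows)}
    return inputs, ({"embed_out": target}, {"class_head": onehot})

onehot_train = np.stack([1 - Y_train, Y_train], axis=1).astype(np.int32)
train_ds = (
    tf.data.Dataset.from_tensor_slices(
        (X_train.astype(np.float32), Y_train.astype(np.int32), onehot_train)
    )
    .map(frame_row, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)  # illustrative batch size
    .prefetch(tf.data.AUTOTUNE)
)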
Why this matters
Large input pipelines often fail for operational rather than algorithmic reasons. Misusing get_single_element() is one such pitfall that surfaces as an INVALID_ARGUMENT error and can distract from the real goal of stabilizing memory. Fixing the sampling pattern, aligning dtypes with model defaults, and pushing windowing into the tf.data pipeline help make training predictable and keep RAM/VRAM use in check.
Takeaways
If a dataset yields multiple elements, don’t call get_single_element() on it directly. Either iterate to the first item or slice with take(1) before extracting. Keep input dtypes consistent with your model’s expectations to avoid unnecessary memory use. Finally, treat windowing as part of the streaming pipeline rather than a precomputation step, so only what’s needed for the current batch lives in memory.