Zhenguo Zhang's Blog
Sharing makes life better
[Python] Add pushback capacity to generators

Generators are a special kind of function that produces values with the keyword yield. Calling such a function returns a generator object, which is an iterator, so its values can be traversed in a for loop.
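To make this concrete, here is a minimal sketch. The function name count_up_to and the variable names are only illustrative; they are not part of the example developed in the rest of this post.

```python
def count_up_to(limit):
    """Hypothetical generator function: yields 0, 1, ..., limit."""
    n = 0
    while n <= limit:
        yield n
        n += 1


gen = count_up_to(3)                # calling the function runs no body code yet
is_own_iterator = iter(gen) is gen  # a generator object is its own iterator
first = next(gen)                   # runs the body up to the first yield
rest = list(gen)                    # consumes whatever is left
print(is_own_iterator, first, rest)  # True 0 [1, 2, 3]
```

Note that a value taken with next() is gone from the generator: there is no built-in way to put it back, which is the gap the wrapper class in this post fills.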

Sometimes, one may want to push an obtained value back to the generator and read it again next time. To do so, one can create an iterator class that wraps the generator. Below is an example:

A generator

First, let’s construct a generator.

def my_generator(max=10):
    '''
    a generator to produce integers from 0 to max (inclusive)
    '''
    n = 0
    while n <= max:
        yield n
        n += 1

The iterable class wrapping a generator

Now, let’s write a class to wrap the generator.

class pushback_wrapper():
    def __init__(self, gen):
        self.gen = gen  # store the generator
        self.stack = []  # list to store pushed-back values
        self.allowPush = True  # set to False once the generator is exhausted

    def __iter__(self):
        '''
        Returning self here lets the object be used in a for loop;
        together with __next__() this makes the class an iterator
        '''
        return self

    def __next__(self):
        if self.stack:  # return pushed-back values first, oldest first
            return self.stack.pop(0)
        else:  # read from the generator
            try:
                return next(self.gen)
            except StopIteration:  # no more values in generator
                self.allowPush = False
                raise

    def push_back(self, val):
        '''
        store pushed-back value; ignored once the generator is exhausted
        '''
        if self.allowPush:
            self.stack.append(val)

Run the code

Finally, let’s try the code. For this purpose, we will output the values from the generator in batches; a batch is printed each time the current value is divisible by 3.

if __name__ == '__main__':
    maxN = 10
    gen = my_generator(maxN)
    myIter = pushback_wrapper(gen)
    record = []
    pushed = False
    for i in myIter:
        # printing when reaching 0, 3, 6, ...
        if i % 3 == 0 and not pushed:
            myIter.push_back(i)  # read this value again in the next step
            print("batch {0}:".format(i // 3), ",".join(record))
            record = []
            pushed = True
        else:
            pushed = False
            record.append(str(i))
    if record:
        print("Last: " + ",".join(record))

    print(f"Job done for max={maxN}")
## batch 0: 
## batch 1: 0,1,2
## batch 2: 3,4,5
## batch 3: 6,7,8
## Last: 9,10
## Job done for max=10

As you can see, we used push_back() to put a value back whenever it is divisible by 3, and it is read again in the next loop step.

The same trick can be used for any other iterator. A plain iterable such as a list must first be turned into an iterator with iter().
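As a rough sketch of that claim, the block below wraps a list instead of a generator. The names PushbackIterator and read_header are hypothetical, and the class is a condensed variant of the wrapper above rather than a copy of it: it calls iter() on its argument so that lists work directly, and it has no allowPush flag, so values pushed back after exhaustion are still accepted.

```python
from collections import deque


class PushbackIterator:
    """Condensed, hypothetical variant of the wrapper shown earlier."""

    def __init__(self, iterable):
        self._it = iter(iterable)  # accepts any iterable or iterator
        self._pushed = deque()     # pushed-back values, oldest first

    def __iter__(self):
        return self

    def __next__(self):
        if self._pushed:
            return self._pushed.popleft()
        return next(self._it)      # StopIteration propagates on its own

    def push_back(self, value):
        self._pushed.append(value)


def read_header(lines):
    """Collect leading lines that start with '#', leaving the rest unread."""
    header = []
    for line in lines:
        if not line.startswith("#"):
            lines.push_back(line)  # not a header line: hand it back
            break
        header.append(line)
    return header


lines = PushbackIterator(["# name", "# date", "row 1", "row 2"])
header = read_header(lines)
body = list(lines)
print(header)  # ['# name', '# date']
print(body)    # ['row 1', 'row 2']
```

Without push_back(), the first non-header line would be consumed by read_header() and lost to whatever code reads the body afterwards.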


Last modified on 2020-06-10
