How to fix java.lang.StackOverflowError when training your ML model in Spark
I was building an ML model using Spark ML (spark.ml). I kept creating and adding more features, hoping to get better results. At some point, it crashed with a java.lang.StackOverflowError when calling .fit(). I was working on an EMR cluster of two r3.8xlarge instances, so memory should not have been the problem here (given my DataFrame's dimensions). Reducing the number of rows in the DataFrame did not help either.
Well, using .checkpoint() saved my day!
What does it look like?
The error looks like this:
Py4JJavaError: An error occurred while calling o362.fit.
: java.lang.StackOverflowError
at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:385)
at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:553)
at org.codehaus.janino.CodeContext.flowAnalysis(CodeContext.java:553)
...
What’s the root cause?
The root cause is that the execution plan grows too large. Spark evaluates transformations lazily: it first builds a DAG (directed acyclic graph) of all the operations it needs to execute, which you can inspect with your_df.explain(extended=True). Every time I created new features with withColumn(), join(), etc., the DAG grew. I kept doing this until I called .fit(), at which point the whole DAG had to be executed, and it was simply too big to handle; the trace above shows Janino, the compiler Spark uses for its generated code, recursing in flowAnalysis until the stack ran out. (Note: this problem occurs not only in .fit() but also in other actions.)
How to fix it?
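The key is to reduce the size of the DAG before the .fit() step. checkpoint() does this: it saves the DataFrame's data to the checkpoint directory (set with spark.sparkContext.setCheckpointDir()) and returns a new DataFrame whose plan starts from that saved data, which "shortens" the DAG before .fit(). Use it from time to time, whenever you think the DAG has grown significantly since the last checkpoint.
Also, cache() is useful in a similar context.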
Example
Here is what I use in PySpark on an EMR cluster.
UPDATE 1: I also found that moving .cache() and .show() to right before .fit() helps. For example:
I'm not sure exactly why, but my theory is that having many cache() or checkpoint() calls in between may overwhelm the checkpoint folder, so some of the earlier cache() and checkpoint() results get discarded.
UPDATE 2: I can confirm that the approach suggested in the first comment works too; thanks, Frantisek Hajnovic! The technique is to convert the DataFrame into an RDD and back into a DataFrame, so that the new DataFrame's plan starts from that RDD instead of carrying the whole chain of transformations, which makes the DAG much shorter. The difference in plan length can be significant. For example:
With .cache() and .checkpoint(), the plan was a few hundred lines long, whereas train_df = spark.createDataFrame(train_df.rdd, schema=train_df.schema) produces a plan of fewer than 10 lines. So the working code looks like this: