#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, TYPE_CHECKING

if TYPE_CHECKING:
    from pyspark.mllib._typing import C, JavaObjectOrPickleDump

import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList

import pyspark.context
from pyspark import RDD, SparkContext
from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SparkSession

# Hack to support float('inf') in Py4j: map the Python spellings of
# non-finite floats to the forms the JVM side can parse before decoding.
_old_smart_decode = py4j.protocol.smart_decode

_float_str_mapping = {
    "nan": "NaN",
    "inf": "Infinity",
    "-inf": "-Infinity",
}


def _new_smart_decode(obj: Any) -> str:
    if isinstance(obj, float):
        s = str(obj)
        return _float_str_mapping.get(s, s)
    return _old_smart_decode(obj)


py4j.protocol.smart_decode = _new_smart_decode


# Simple names of JVM classes whose instances SerDe.dumps() can pickle
# for transport back to Python.
_picklable_classes = [
    "LinkedList",
    "SparseVector",
    "DenseVector",
    "DenseMatrix",
    "Rating",
    "LabeledPoint",
]


# this will call the MLlib version of pythonToJava()
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
    """Return a JavaRDD of Object by unpickling

    It will convert each Python object into a Java object by Pickle, whether
    or not the RDD is serialized in batch.
    """
    rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer()))
    assert rdd.ctx._jvm is not None
    return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)
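
# A minimal round-trip sketch of the conversion helpers in this module,
# assuming an active SparkContext. Kept as comments so importing this module
# has no side effects; the variable names are illustrative only:
#
#     sc = SparkContext.getOrCreate()
#     # Reserialize a Python RDD and hand it to the JVM as a JavaRDD[Object]:
#     jrdd = _to_java_object_rdd(sc.parallelize([1, 2, 3]))
#     # _java2py() (defined below) converts the JavaRDD back into a Python
#     # RDD whose elements are unpickled on access:
#     rdd = _java2py(sc, jrdd)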
""" rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) assert rdd.ctx._jvm is not None return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) def _py2java(sc: SparkContext, obj: Any) -> JavaObject: """Convert Python object into Java""" if isinstance(obj, RDD): obj = _to_java_object_rdd(obj) elif isinstance(obj, DataFrame): obj = obj._jdf elif isinstance(obj, SparkContext): obj = obj._jsc elif isinstance(obj, list): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): pass elif isinstance(obj, (int, float, bool, bytes, str)): pass else: data = bytearray(CPickleSerializer().dumps(obj)) assert sc._jvm is not None obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data) return obj def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any: if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD if clsName != "JavaRDD" and clsName.endswith("RDD"): r = r.toJavaRDD() clsName = "JavaRDD" assert sc._jvm is not None if clsName == "JavaRDD": jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r) return RDD(jrdd, sc) if clsName == "Dataset": return DataFrame(r, SparkSession._getActiveSessionOrCreate()) if clsName in _picklable_classes: r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): try: r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) except Py4JJavaError: pass # not pickable if isinstance(r, (bytearray, bytes)): r = CPickleSerializer().loads(bytes(r), encoding=encoding) return r def callJavaFunc( sc: pyspark.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any ) -> Any: """Call Java Function""" java_args = [_py2java(sc, a) for a in args] return _java2py(sc, func(*java_args)) def callMLlibFunc(name: str, *args: Any) -> Any: """Call API in PythonMLLibAPI""" sc = SparkContext.getOrCreate() assert sc._jvm is not None api = getattr(sc._jvm.PythonMLLibAPI(), name) return callJavaFunc(sc, api, *args) class JavaModelWrapper: """ Wrapper for the model in JVM """ def __init__(self, java_model: JavaObject): self._sc = SparkContext.getOrCreate() self._java_model = java_model def __del__(self) -> None: assert self._sc._gateway is not None self._sc._gateway.detach(self._java_model) def call(self, name: str, *a: Any) -> Any: """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) def inherit_doc(cls: "C") -> "C": """ A decorator that makes a class inherit documentation from its parents. """ for name, func in vars(cls).items(): # only inherit docstring for public functions if name.startswith("_"): continue if not func.__doc__: for parent in cls.__bases__: parent_func = getattr(parent, name, None) if parent_func and getattr(parent_func, "__doc__", None): func.__doc__ = parent_func.__doc__ break return cls