Spark from Zero to Development (5): A First Look at Spark SQL


Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided by Spark SQL give Spark more information about the structure of both the data and the computation being performed. Internally, Spark SQL uses this extra information to perform additional optimizations. There are several ways to interact with Spark SQL, including SQL and the Dataset API. When computing a result, the same execution engine is used regardless of which API or language you use to express the computation. This unification means developers can easily switch back and forth between APIs, choosing whichever provides the most natural way to express a given transformation.

One use of Spark SQL is to execute SQL queries. Spark SQL can also be used to read data from an existing Hive installation. When SQL is run from within another programming language, the results are returned as a Dataset/DataFrame. You can also interact with the SQL interface through the command line or over JDBC/ODBC.

Datasets and DataFrames

A Dataset is a distributed collection of data. Dataset is a new interface added in Spark 1.6 that provides the benefits of RDDs (strong typing, the ability to use powerful lambda functions) together with the benefits of Spark SQL's optimized execution engine. A Dataset can be constructed from JVM objects and then manipulated using functional transformations (map, flatMap, filter, etc.). A DataFrame is a Dataset organized into named columns; in the Java API a DataFrame is represented as a Dataset<Row>, which is the type that appears throughout the examples below.

1. Getting Started

The entry point to all functionality in Spark is the SparkSession class. To create a basic SparkSession, just use SparkSession.builder():


import org.apache.spark.sql.SparkSession;

SparkSession spark = SparkSession
  .builder()
  .appName("Java Spark SQL basic example")
  .config("spark.some.config.option", "some-value")
  .getOrCreate();
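
The snippets in this article all assume an existing SparkSession named spark. As a reference, here is a minimal, self-contained skeleton they could be dropped into (the class name and the local[*] master setting are illustrative and not part of the official example):

import org.apache.spark.sql.SparkSession;

public class SparkSqlQuickStart {
  public static void main(String[] args) {
    // Build (or reuse) a session; master("local[*]") is only for running locally.
    SparkSession spark = SparkSession
      .builder()
      .appName("Java Spark SQL basic example")
      .master("local[*]")
      .getOrCreate();

    // ... the examples from the sections below go here ...

    // Release resources when finished.
    spark.stop();
  }
}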

1.1 Creating DataFrames

With a SparkSession, applications can create DataFrames from an existing RDD, from a Hive table, or from Spark data sources.

An example of creating a DataFrame from the contents of a JSON file:


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

Dataset<Row> df = spark.read().json("examples/src/main/resources/people.json");

// Displays the content of the DataFrame to stdout
df.show();
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+

1.2 Dataset Operations

A basic example of structured data processing with Datasets:


// col("...") is preferable to df.col("...")
import static org.apache.spark.sql.functions.col;

// Print the schema in a tree format
df.printSchema();
// root
// |-- age: long (nullable = true)
// |-- name: string (nullable = true)

// Select only the "name" column
df.select("name").show();
// +-------+
// |   name|
// +-------+
// |Michael|
// |   Andy|
// | Justin|
// +-------+

// Select everybody, but increment the age by 1
df.select(col("name"), col("age").plus(1)).show();
// +-------+---------+
// |   name|(age + 1)|
// +-------+---------+
// |Michael|     null|
// |   Andy|       31|
// | Justin|       20|
// +-------+---------+

// Select people older than 21
df.filter(col("age").gt(21)).show();
// +---+----+
// |age|name|
// +---+----+
// | 30|Andy|
// +---+----+

// Count people by age (group by the age column and count each group)
df.groupBy("age").count().show();
// +----+-----+
// | age|count|
// +----+-----+
// |  19|    1|
// |null|    1|
// |  30|    1|
// +----+-----+

1.3 Running SQL Queries Programmatically

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("people");

Dataset<Row> sqlDF = spark.sql("SELECT * FROM people");
sqlDF.show();
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+

1.4 Global Temporary View

Temporary views in Spark SQL are session-scoped: they disappear when the session that created them terminates. If you want a temporary view that is shared among all sessions and stays alive until the Spark application terminates, you can create a global temporary view. Global temporary views are tied to the system-preserved database global_temp, and you must use the qualified name to refer to them, e.g. SELECT * FROM global_temp.view1.


// Register the DataFrame as a global temporary view
df.createGlobalTempView("people");

// Global temporary views are tied to the system-preserved database `global_temp`
spark.sql("SELECT * FROM global_temp.people").show();
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+

// Global temporary views are cross-session
spark.newSession().sql("SELECT * FROM global_temp.people").show();
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+
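
For contrast, here is a small sketch of how an ordinary session-scoped temporary view behaves across sessions (the view name people_local is illustrative):

// A regular temporary view is only visible in the session that created it.
df.createOrReplaceTempView("people_local");
spark.sql("SELECT count(*) FROM people_local").show();  // works in the current session

// A brand-new session does not see it; the query below would typically fail
// with an AnalysisException ("Table or view not found").
// spark.newSession().sql("SELECT * FROM people_local").show();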

1.5 Creating Datasets

Datasets are similar to RDDs; however, instead of using Java serialization or Kryo, they use a specialized Encoder to serialize the objects for processing or transmission over the network. While both encoders and standard serialization are responsible for turning an object into bytes, encoders are dynamically generated code that use a format which allows Spark to perform many operations such as filtering, sorting, and hashing without deserializing the bytes back into an object.


import java.util.Arrays;
import java.util.Collections;
import java.io.Serializable;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;

public static class Person implements Serializable {
  private String name;
  private int age;

  public String getName() {
    return name;
  }

  public void setName(String name) {
    this.name = name;
  }

  public int getAge() {
    return age;
  }

  public void setAge(int age) {
    this.age = age;
  }
}

// Create an instance of a Bean class
Person person = new Person();
person.setName("Andy");
person.setAge(32);

// Encoders are created for Java beans
Encoder<Person> personEncoder = Encoders.bean(Person.class);
Dataset<Person> javaBeanDS = spark.createDataset(
  Collections.singletonList(person),
  personEncoder
);
javaBeanDS.show();
// +---+----+
// |age|name|
// +---+----+
// | 32|Andy|
// +---+----+

// Encoders for most common types are provided in class Encoders
Encoder<Integer> integerEncoder = Encoders.INT();
Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
Dataset<Integer> transformedDS = primitiveDS.map(
  (MapFunction<Integer, Integer>) value -> value + 1,
  integerEncoder);
transformedDS.collect(); // Returns [2, 3, 4]

// A DataFrame can be converted to a Dataset by providing a class; the mapping is based on field names
String path = "examples/src/main/resources/people.json";
Dataset<Person> peopleDS = spark.read().json(path).as(personEncoder);
peopleDS.show();
// +----+-------+
// | age|   name|
// +----+-------+
// |null|Michael|
// |  30|   Andy|
// |  19| Justin|
// +----+-------+

1.6 Interoperating with RDDs

Spark SQL supports two different methods for converting existing RDDs into Datasets.

The first method uses reflection to infer the schema of an RDD that contains objects of a specific type. This reflection-based approach leads to more concise code.

The second method is a programmatic interface that lets you construct a schema and then apply it to an existing RDD.

1.6.1 Inferring the Schema Using Reflection

Spark SQL supports automatically converting an RDD of JavaBeans into a DataFrame. The BeanInfo obtained via reflection defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain Map fields; nested JavaBeans and List or Array fields are supported, however. You can create a JavaBean by writing a class that implements Serializable and has getters and setters for all of its fields.


import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;

// Create an RDD of Person objects from a text file
JavaRDD<Person> peopleRDD = spark.read()
  .textFile("examples/src/main/resources/people.txt")
  .javaRDD()
  .map(line -> {
    String[] parts = line.split(",");
    Person person = new Person();
    person.setName(parts[0]);
    person.setAge(Integer.parseInt(parts[1].trim()));
    return person;
  });

// Apply a schema to an RDD of JavaBeans to get a DataFrame
Dataset<Row> peopleDF = spark.createDataFrame(peopleRDD, Person.class);
// Register the DataFrame as a temporary view
peopleDF.createOrReplaceTempView("people");

// SQL statements can be run by using the sql methods provided by spark
Dataset<Row> teenagersDF = spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19");

// The columns of a row in the result can be accessed by field index
Encoder<String> stringEncoder = Encoders.STRING();
Dataset<String> teenagerNamesByIndexDF = teenagersDF.map(
  (MapFunction<Row, String>) row -> "Name: " + row.getString(0),
  stringEncoder);
teenagerNamesByIndexDF.show();
// +------------+
// |       value|
// +------------+
// |Name: Justin|
// +------------+

// or by field name
Dataset<String> teenagerNamesByFieldDF = teenagersDF.map(
  (MapFunction<Row, String>) row -> "Name: " + row.<String>getAs("name"),
  stringEncoder);
teenagerNamesByFieldDF.show();
// +------------+
// |       value|
// +------------+
// |Name: Justin|
// +------------+

1.6.2 Programmatically Specifying the Schema

When JavaBean classes cannot be defined ahead of time (for example, when the structure of records is encoded in a string, or a text dataset will be parsed and fields projected differently for different users), a Dataset<Row> can be created programmatically in three steps:

  1. Create an RDD of Rows from the original RDD;
  2. Create the schema, represented by a StructType, matching the structure of the Rows in the RDD created in step 1;
  3. Apply the schema to the RDD of Rows via the createDataFrame method provided by SparkSession.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Create an RDD
JavaRDD<String> peopleRDD = spark.sparkContext()
  .textFile("examples/src/main/resources/people.txt", 1)
  .toJavaRDD();

// The schema is encoded in a string
String schemaString = "name age";

// Generate the schema based on the string of schema
List<StructField> fields = new ArrayList<>();
for (String fieldName : schemaString.split(" ")) {
  StructField field = DataTypes.createStructField(fieldName, DataTypes.StringType, true);
  fields.add(field);
}
StructType schema = DataTypes.createStructType(fields);

// Convert records of the RDD (people) to Rows
JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> {
  String[] attributes = record.split(",");
  return RowFactory.create(attributes[0], attributes[1].trim());
});

// Apply the schema to the RDD
Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);

// Create a temporary view using the DataFrame
peopleDataFrame.createOrReplaceTempView("people");

// SQL can be run over a temporary view created using DataFrames
Dataset<Row> results = spark.sql("SELECT name FROM people");

// The results of SQL queries are DataFrames and support all the normal RDD operations
// The columns of a row in the result can be accessed by field index or by field name
Dataset<String> namesDS = results.map(
  (MapFunction<Row, String>) row -> "Name: " + row.getString(0),
  Encoders.STRING());
namesDS.show();
// +-------------+
// |        value|
// +-------------+
// |Name: Michael|
// |   Name: Andy|
// | Name: Justin|
// +-------------+

1.7 Aggregations

The built-in DataFrame functions provide common aggregations such as count(), countDistinct(), avg(), max(), min(), and so on. While these functions are designed for DataFrames, Spark SQL also has type-safe versions of some of them in Scala and Java that work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own.
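
The article does not show the built-in aggregate functions in action, so here is a minimal sketch of how they might be used (it reuses the employees.json file introduced later in this section and assumes the surrounding spark session and the Dataset/Row imports are already in scope):

import static org.apache.spark.sql.functions.avg;
import static org.apache.spark.sql.functions.max;

// Built-in aggregates over an untyped DataFrame
Dataset<Row> employeesDF = spark.read().json("examples/src/main/resources/employees.json");
employeesDF.agg(avg("salary"), max("salary")).show();
// With the sample data shown below, the average salary is 3750.0 and the maximum is 4500.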

1.7.1 Untyped User-Defined Aggregate Functions

To implement a custom untyped aggregate function, users must extend the UserDefinedAggregateFunction abstract class.


import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public static class MyAverage extends UserDefinedAggregateFunction {

  private StructType inputSchema;
  private StructType bufferSchema;

  public MyAverage() {
    List<StructField> inputFields = new ArrayList<>();
    inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
    inputSchema = DataTypes.createStructType(inputFields);

    List<StructField> bufferFields = new ArrayList<>();
    bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
    bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
    bufferSchema = DataTypes.createStructType(bufferFields);
  }
  // Data types of input arguments of this aggregate function
  public StructType inputSchema() {
    return inputSchema;
  }
  // Data types of values in the aggregation buffer
  public StructType bufferSchema() {
    return bufferSchema;
  }
  // The data type of the returned value
  public DataType dataType() {
    return DataTypes.DoubleType;
  }
  // Whether this function always returns the same output on the identical input
  public boolean deterministic() {
    return true;
  }
  // Initializes the given aggregation buffer. The buffer itself is a `Row` that, in addition to
  // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
  // the opportunity to update its values. Note that arrays and maps inside the buffer are still
  // immutable.
  public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, 0L);
    buffer.update(1, 0L);
  }
  // Updates the given aggregation buffer `buffer` with new input data from `input`
  public void update(MutableAggregationBuffer buffer, Row input) {
    if (!input.isNullAt(0)) {
      long updatedSum = buffer.getLong(0) + input.getLong(0);
      long updatedCount = buffer.getLong(1) + 1;
      buffer.update(0, updatedSum);
      buffer.update(1, updatedCount);
    }
  }
  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
    long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
    long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
    buffer1.update(0, mergedSum);
    buffer1.update(1, mergedCount);
  }
  // Calculates the final result
  public Double evaluate(Row buffer) {
    return ((double) buffer.getLong(0)) / buffer.getLong(1);
  }
}

// Register the function to access it
spark.udf().register("myAverage", new MyAverage());

Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+

1.7.2 Type-Safe User-Defined Aggregate Functions

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class.


import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;

public static class Employee implements Serializable {
  private String name;
  private long salary;

  // Constructors, getters, setters...

}

public static class Average implements Serializable {
  private long sum;
  private long count;

  // Constructors, getters, setters...

}

public static class MyAverage extends Aggregator<Employee, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Employee employee) {
    long newSum = buffer.getSum() + employee.getSalary();
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // Specifies the Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // Specifies the Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
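
The bean bodies are elided in the example above ("Constructors, getters, setters..."). For reference, here is a minimal sketch of what the Average helper might look like so that Encoders.bean(Average.class) and the calls in MyAverage compile (the Employee bean is analogous; the exact constructors and accessors shown here are assumptions, not part of the original example):

public static class Average implements Serializable {
  private long sum;
  private long count;

  // Bean encoders typically need a public no-arg constructor
  public Average() {}

  // Used by MyAverage.zero()
  public Average(long sum, long count) {
    this.sum = sum;
    this.count = count;
  }

  public long getSum() { return sum; }
  public void setSum(long sum) { this.sum = sum; }
  public long getCount() { return count; }
  public void setCount(long count) { this.count = count; }
}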

All of the material above comes from the first (Getting Started) section of the official Spark SQL documentation; for the remaining topics, such as performance tuning, please refer to the official documentation.
