Skip to content

Commit

Permalink
refactor arrow to pg context
Browse files Browse the repository at this point in the history
  • Loading branch information
aykut-bozkurt committed Jan 21, 2025
1 parent ab5e46e commit bf97833
Show file tree
Hide file tree
Showing 11 changed files with 420 additions and 332 deletions.
309 changes: 44 additions & 265 deletions src/arrow_parquet/arrow_to_pg.rs

Large diffs are not rendered by default.

19 changes: 6 additions & 13 deletions src/arrow_parquet/arrow_to_pg/composite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,17 @@ impl<'a> ArrowArrayToPgType<PgHeapTuple<'a, AllocatedByRust>> for StructArray {

let mut datums = vec![];

for attribute_context in context
.attribute_contexts
.as_ref()
.expect("each attribute of the tuple should have a context")
{
for attribute_context in context.attribute_contexts() {
let column_data = self
.column_by_name(&attribute_context.name)
.unwrap_or_else(|| panic!("column {} not found", &attribute_context.name));
.column_by_name(attribute_context.name())
.unwrap_or_else(|| panic!("column {} not found", &attribute_context.name()));

let datum = to_pg_datum(column_data.into_data(), attribute_context);

datums.push(datum);
}

let tupledesc = context
.attribute_tupledesc
.as_ref()
.expect("Expected attribute tupledesc");
let tupledesc = context.tupledesc();

Some(
unsafe { PgHeapTuple::from_datums(tupledesc.clone(), datums) }.unwrap_or_else(|e| {
Expand All @@ -46,15 +39,15 @@ impl<'a> ArrowArrayToPgType<PgHeapTuple<'a, AllocatedByRust>> for StructArray {
impl<'a> ArrowArrayToPgType<Vec<Option<PgHeapTuple<'a, AllocatedByRust>>>> for StructArray {
fn to_pg_type(
self,
context: &ArrowToPgAttributeContext,
element_context: &ArrowToPgAttributeContext,
) -> Option<Vec<Option<PgHeapTuple<'a, AllocatedByRust>>>> {
let len = self.len();
let mut values = Vec::with_capacity(len);

for i in 0..len {
let tuple = self.slice(i, 1);

let tuple = tuple.to_pg_type(&context.clone());
let tuple = tuple.to_pg_type(element_context);

values.push(tuple);
}
Expand Down
320 changes: 320 additions & 0 deletions src/arrow_parquet/arrow_to_pg/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
use std::ops::Deref;

use arrow_schema::{DataType, FieldRef, Fields};
use pgrx::{
pg_sys::{FormData_pg_attribute, Oid, NUMERICOID},
PgTupleDesc,
};

use crate::type_compat::pg_arrow_type_conversions::extract_precision_and_scale_from_numeric_typmod;

use super::{
array_element_typoid, collect_attributes_for, domain_array_base_elem_type, is_array_type,
is_composite_type, is_map_type, is_postgis_geometry_type, tuple_desc, CollectAttributesFor,
};

// ArrowToPgAttributeContext contains the information needed to convert an Arrow array
// to a PostgreSQL attribute.
//
// It pairs the facts shared by every attribute (name, Arrow data type, PostgreSQL
// type oid and typmod) with a type-specific `type_context`.
#[derive(Clone)]
pub(crate) struct ArrowToPgAttributeContext {
// common info for all types
// attribute name, copied from the `name` argument of `new`
name: String,
// Arrow data type used for the conversion: the cast target when `new` received
// one, otherwise the data type of the Arrow field passed to `new`
data_type: DataType,
// true when `new` was given a `cast_to_type`
needs_cast: bool,
// PostgreSQL type oid of the attribute
typoid: Oid,
// PostgreSQL type modifier of the attribute
typmod: i32,

// type-specific info
// built by `ArrowToPgAttributeTypeContext::new` from `typoid`, `typmod` and `data_type`
type_context: ArrowToPgAttributeTypeContext,
}

// Dereferences to the type-specific context, so its accessors (e.g. `tupledesc()`,
// `attribute_contexts()`, `entries_context()`) can be called directly on an
// `ArrowToPgAttributeContext`.
impl Deref for ArrowToPgAttributeContext {
type Target = ArrowToPgAttributeTypeContext;

// Borrows the embedded `type_context`.
fn deref(&self) -> &Self::Target {
&self.type_context
}
}

impl ArrowToPgAttributeContext {
// Builds the conversion context for a single attribute.
//
// - `name`: attribute name, stored as an owned string.
// - `typoid` / `typmod`: PostgreSQL type oid and type modifier of the attribute.
// - `field`: Arrow field of the attribute; its data type is used when no cast is requested.
// - `cast_to_type`: when `Some`, this type replaces the field's data type and
//   `needs_cast` is set.
//
// The type-specific part of the context is derived from the resulting data type.
pub(crate) fn new(
    name: &str,
    typoid: Oid,
    typmod: i32,
    field: FieldRef,
    cast_to_type: Option<DataType>,
) -> Self {
    let needs_cast = cast_to_type.is_some();

    // a requested cast target wins over the field's own data type
    let data_type = cast_to_type.unwrap_or_else(|| field.data_type().clone());

    let type_context = ArrowToPgAttributeTypeContext::new(typoid, typmod, &data_type);

    Self {
        name: name.to_owned(),
        data_type,
        needs_cast,
        typoid,
        typmod,
        type_context,
    }
}

// Returns the PostgreSQL type oid this context was created with.
pub(crate) fn typoid(&self) -> Oid {
self.typoid
}

// Returns the PostgreSQL type modifier this context was created with.
pub(crate) fn typmod(&self) -> i32 {
self.typmod
}

// Returns the attribute name as a borrowed string slice.
pub(crate) fn name(&self) -> &str {
&self.name
}

// Returns true when a cast target type was supplied at construction time.
pub(crate) fn needs_cast(&self) -> bool {
self.needs_cast
}

// Returns the Arrow data type stored in the context (the cast target if one was
// supplied to `new`, otherwise the Arrow field's data type).
pub(crate) fn data_type(&self) -> &DataType {
&self.data_type
}

// Returns the timezone recorded for a primitive attribute.
//
// Panics with "missing timezone in context" when the context is not primitive
// or when no timezone was recorded for it.
pub(crate) fn timezone(&self) -> &str {
    if let ArrowToPgAttributeTypeContext::Primitive {
        timezone: Some(timezone),
        ..
    } = &self.type_context
    {
        return timezone.as_str();
    }

    panic!("missing timezone in context")
}
}

// ArrowToPgAttributeTypeContext contains type specific information needed to
// convert an Arrow array to a PostgreSQL attribute.
//
// The variant is chosen by `new` from the PostgreSQL type oid: array, composite,
// map, and otherwise primitive.
#[derive(Clone)]
pub(crate) enum ArrowToPgAttributeTypeContext {
// any type that is not an array, composite or map
Primitive {
// result of `is_postgis_geometry_type` for the type oid
is_geometry: bool,
// set only when the type oid is NUMERICOID (extracted from the typmod)
precision: Option<u32>,
// set only when the type oid is NUMERICOID (extracted from the typmod)
scale: Option<u32>,
// set only when the Arrow data type is a Timestamp that carries a timezone
timezone: Option<String>,
},
// array type; holds the context of its element type
Array {
element_context: Box<ArrowToPgAttributeContext>,
},
// composite type; holds its tuple descriptor and one context per attribute
Composite {
tupledesc: PgTupleDesc<'static>,
attribute_contexts: Vec<ArrowToPgAttributeContext>,
},
// map type; holds the context of the map's entries field
Map {
entries_context: Box<ArrowToPgAttributeContext>,
},
}

impl ArrowToPgAttributeTypeContext {
// constructors
// Picks the type-specific context for `typoid`.
//
// The checks run in a fixed order: array, then composite, then map. Anything
// that matches none of them is treated as a primitive type.
fn new(typoid: Oid, typmod: i32, data_type: &DataType) -> Self {
    if is_array_type(typoid) {
        return Self::new_array(typoid, typmod, data_type);
    }

    if is_composite_type(typoid) {
        return Self::new_composite(typoid, typmod, data_type);
    }

    if is_map_type(typoid) {
        return Self::new_map(typoid, data_type);
    }

    Self::new_primitive(typoid, typmod, data_type)
}

fn new_primitive(typoid: Oid, typmod: i32, data_type: &DataType) -> Self {
let precision;
let scale;
if typoid == NUMERICOID {
let (p, s) = extract_precision_and_scale_from_numeric_typmod(typmod);
precision = Some(p);
scale = Some(s);
} else {
precision = None;
scale = None;
}

let is_geometry = is_postgis_geometry_type(typoid);

let timezone = match &data_type {
DataType::Timestamp(_, Some(timezone)) => Some(timezone.to_string()),
_ => None,
};

Self::Primitive {
is_geometry,
precision,
scale,
timezone,
}
}

// Builds the context for an array type.
//
// The element context is created from the Arrow list's element field, the
// element type oid resolved from the array's `typoid`, and the array's own
// `typmod`. No cast is requested for the element.
//
// `data_type` must be `DataType::List`; any other variant hits `unreachable!()`.
fn new_array(typoid: Oid, typmod: i32, data_type: &DataType) -> Self {
    let element_typoid = array_element_typoid(typoid);

    // borrow the list's element field; it is cloned once when handed to the element context
    let element_field = match data_type {
        DataType::List(field) => field,
        _ => unreachable!(),
    };

    let element_context = ArrowToPgAttributeContext::new(
        element_field.name(),
        element_typoid,
        typmod,
        element_field.clone(),
        None,
    );

    Self::Array {
        element_context: Box::new(element_context),
    }
}

fn new_composite(typoid: Oid, typmod: i32, data_type: &DataType) -> Self {
let tupledesc = tuple_desc(typoid, typmod);
let fields = match data_type {
arrow::datatypes::DataType::Struct(fields) => fields.clone(),
_ => unreachable!(),

Check warning on line 184 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L184

Added line #L184 was not covered by tests
};

let attributes = collect_attributes_for(CollectAttributesFor::Other, &tupledesc);

// we only cast the top-level attributes, which already covers the nested attributes
let cast_to_types = None;

let attribute_contexts =
collect_arrow_to_pg_attribute_contexts(&attributes, &fields, cast_to_types);

Self::Composite {
tupledesc,
attribute_contexts,
}
}

fn new_map(typoid: Oid, data_type: &DataType) -> Self {
let (entries_typoid, entries_typmod) = domain_array_base_elem_type(typoid);

Check warning on line 202 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L201-L202

Added lines #L201 - L202 were not covered by tests

let entries_field = match data_type {
arrow::datatypes::DataType::Map(entries_field, _) => entries_field.clone(),
_ => unreachable!(),

Check warning on line 206 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L204-L206

Added lines #L204 - L206 were not covered by tests
};

let entries_context = Box::new(ArrowToPgAttributeContext::new(
entries_field.name(),
entries_typoid,
entries_typmod,
entries_field.clone(),
None,
));

Self::Map { entries_context }
}

Check warning on line 218 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L209-L218

Added lines #L209 - L218 were not covered by tests

// primitive type methods
// Returns the numeric precision stored in a primitive context.
//
// Panics with "missing precision in context" when the context is not primitive
// or holds no precision.
pub(crate) fn precision(&self) -> u32 {
    if let Self::Primitive {
        precision: Some(precision),
        ..
    } = self
    {
        *precision
    } else {
        panic!("missing precision in context")
    }
}

// Returns the numeric scale stored in a primitive context.
//
// Panics with "missing scale in context" when the context is not primitive
// or holds no scale.
pub(crate) fn scale(&self) -> u32 {
    if let Self::Primitive {
        scale: Some(scale), ..
    } = self
    {
        *scale
    } else {
        panic!("missing scale in context")
    }
}

// composite type methods
// Returns a clone of the tuple descriptor held by a composite context.
//
// Panics with "missing tupledesc in context" for any non-composite context.
pub(crate) fn tupledesc(&self) -> PgTupleDesc<'static> {
    if let Self::Composite { tupledesc, .. } = self {
        tupledesc.clone()
    } else {
        panic!("missing tupledesc in context")
    }
}

// Returns the per-attribute contexts held by a composite context.
//
// Panics with "missing attribute contexts in context" for any non-composite context.
pub(crate) fn attribute_contexts(&self) -> &Vec<ArrowToPgAttributeContext> {
    if let Self::Composite {
        attribute_contexts, ..
    } = self
    {
        attribute_contexts
    } else {
        panic!("missing attribute contexts in context")
    }
}

// map type methods
// Returns the context of the entries field held by a map context.
//
// Panics with "missing entries context in context" for any non-map context.
pub(crate) fn entries_context(&self) -> &ArrowToPgAttributeContext {
    if let Self::Map { entries_context } = self {
        entries_context
    } else {
        panic!("missing entries context in context")
    }
}

Check warning on line 262 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L262

Added line #L262 was not covered by tests

// array type methods
// Returns the context of the element type held by an array context.
//
// Panics with "not a context for an array type" for any non-array context.
pub(crate) fn element_context(&self) -> &ArrowToPgAttributeContext {
    if let Self::Array { element_context } = self {
        element_context
    } else {
        panic!("not a context for an array type")
    }
}

// type checks
pub(crate) fn is_geometry(&self) -> bool {
match &self {
ArrowToPgAttributeTypeContext::Primitive { is_geometry, .. } => *is_geometry,
_ => false,

Check warning on line 278 in src/arrow_parquet/arrow_to_pg/context.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/context.rs#L278

Added line #L278 was not covered by tests
}
}
}

// Creates one `ArrowToPgAttributeContext` per PostgreSQL attribute, in attribute order.
//
// - `attributes`: the PostgreSQL attributes to build contexts for.
// - `fields`: Arrow fields; each attribute is matched to the field with the same name.
// - `cast_to_types`: optional per-attribute cast targets, indexed like `attributes`.
//   When `None`, no attribute is cast.
//
// Panics when an attribute has no Arrow field with the same name, or when
// `cast_to_types` is present but has no entry at an attribute's index.
pub(crate) fn collect_arrow_to_pg_attribute_contexts(
    attributes: &[FormData_pg_attribute],
    fields: &Fields,
    cast_to_types: Option<Vec<Option<DataType>>>,
) -> Vec<ArrowToPgAttributeContext> {
    attributes
        .iter()
        .enumerate()
        .map(|(idx, attribute)| {
            let attribute_name = attribute.name();
            let attribute_typoid = attribute.type_oid().value();
            let attribute_typmod = attribute.type_mod();

            // attributes are paired with Arrow fields by name, not by position
            let field = fields
                .iter()
                .find(|field| field.name() == attribute_name)
                .unwrap_or_else(|| panic!("failed to find field {}", attribute_name))
                .clone();

            let cast_to_type = match cast_to_types.as_ref() {
                Some(cast_to_types) => {
                    debug_assert!(cast_to_types.len() == attributes.len());
                    cast_to_types.get(idx).cloned().expect("cast_to_type null")
                }
                None => None,
            };

            ArrowToPgAttributeContext::new(
                attribute_name,
                attribute_typoid,
                attribute_typmod,
                field,
                cast_to_type,
            )
        })
        .collect()
}
9 changes: 6 additions & 3 deletions src/arrow_parquet/arrow_to_pg/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ impl<'a> ArrowArrayToPgType<Map<'a>> for MapArray {
let entries_array = self.value(0);

let entries: Option<Vec<Option<PgHeapTuple<AllocatedByRust>>>> =
entries_array.to_pg_type(context);
entries_array.to_pg_type(context.entries_context());

Check warning on line 17 in src/arrow_parquet/arrow_to_pg/map.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/map.rs#L17

Added line #L17 was not covered by tests

if let Some(entries) = entries {
let entries_datum = entries.into_datum();
Expand All @@ -38,13 +38,16 @@ impl<'a> ArrowArrayToPgType<Map<'a>> for MapArray {

// crunchy_map.key_<type1>_val_<type2>[]
impl<'a> ArrowArrayToPgType<Vec<Option<Map<'a>>>> for MapArray {
fn to_pg_type(self, context: &ArrowToPgAttributeContext) -> Option<Vec<Option<Map<'a>>>> {
fn to_pg_type(
self,
element_context: &ArrowToPgAttributeContext,
) -> Option<Vec<Option<Map<'a>>>> {

Check warning on line 44 in src/arrow_parquet/arrow_to_pg/map.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/map.rs#L41-L44

Added lines #L41 - L44 were not covered by tests
let mut maps = vec![];

for entries_array in self.iter() {
if let Some(entries_array) = entries_array {
let entries: Option<Vec<Option<PgHeapTuple<AllocatedByRust>>>> =
entries_array.to_pg_type(context);
entries_array.to_pg_type(element_context.entries_context());

Check warning on line 50 in src/arrow_parquet/arrow_to_pg/map.rs

View check run for this annotation

Codecov / codecov/patch

src/arrow_parquet/arrow_to_pg/map.rs#L50

Added line #L50 was not covered by tests

if let Some(entries) = entries {
let entries_datum = entries.into_datum();
Expand Down
Loading

0 comments on commit bf97833

Please sign in to comment.