Skip to content

Commit

Permalink
Fix error when we have multiple different map types
Browse files Browse the repository at this point in the history
  • Loading branch information
aykut-bozkurt committed Jan 16, 2025
1 parent 71da48c commit e37d709
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 28 deletions.
25 changes: 18 additions & 7 deletions src/arrow_parquet/arrow_to_pg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
type_compat::{
fallback_to_text::{reset_fallback_to_text_context, FallbackToText},
geometry::{is_postgis_geometry_type, Geometry},
map::{is_map_type, Map},
map::{is_map_type, reset_map_type_context, Map},
},
};

Expand Down Expand Up @@ -55,6 +55,7 @@ pub(crate) struct ArrowToPgAttributeContext {
needs_cast: bool,
typoid: Oid,
typmod: i32,
map_typoid: Option<Oid>,
is_geometry: bool,
attribute_contexts: Option<Vec<ArrowToPgAttributeContext>>,
attribute_tupledesc: Option<PgTupleDesc<'static>>,
Expand Down Expand Up @@ -82,21 +83,24 @@ impl ArrowToPgAttributeContext {
let is_array = is_array_type(typoid);
let is_composite;
let is_geometry;
let is_map;
let attribute_typoid;
let attribute_field;
let map_typoid;

if is_array {
let element_typoid = array_element_typoid(typoid);

is_composite = is_composite_type(element_typoid);
is_geometry = is_postgis_geometry_type(element_typoid);
is_map = is_map_type(element_typoid);

if is_map {
if is_map_type(element_typoid) {
map_typoid = Some(element_typoid);

let entries_typoid = domain_array_base_elem_typoid(element_typoid);
attribute_typoid = entries_typoid;
} else {
map_typoid = None;

attribute_typoid = element_typoid;
}

Expand All @@ -107,19 +111,21 @@ impl ArrowToPgAttributeContext {
} else {
is_composite = is_composite_type(typoid);
is_geometry = is_postgis_geometry_type(typoid);
is_map = is_map_type(typoid);

if is_map {
if is_map_type(typoid) {
map_typoid = Some(typoid);

let entries_typoid = domain_array_base_elem_typoid(typoid);
attribute_typoid = entries_typoid;
} else {
map_typoid = None;
attribute_typoid = typoid;
}

attribute_field = field.clone();
}

let attribute_tupledesc = if is_composite || is_map {
let attribute_tupledesc = if is_composite || map_typoid.is_some() {
Some(tuple_desc(attribute_typoid, typmod))
} else {
None
Expand Down Expand Up @@ -183,6 +189,7 @@ impl ArrowToPgAttributeContext {
needs_cast,
typoid: attribute_typoid,
typmod,
map_typoid,
is_geometry,
attribute_contexts,
attribute_tupledesc,
Expand Down Expand Up @@ -367,6 +374,8 @@ fn to_pg_nonarray_datum(
)
}
DataType::Map(_, _) => {
reset_map_type_context(attribute_context.map_typoid.expect("missing map typoid"));

to_pg_datum!(MapArray, Map, primitive_array, attribute_context)
}
_ => {
Expand Down Expand Up @@ -525,6 +534,8 @@ fn to_pg_array_datum(
)
}
DataType::Map(_, _) => {
reset_map_type_context(attribute_context.map_typoid.expect("missing map typoid"));

to_pg_datum!(MapArray, Vec<Option<Map>>, list_array, attribute_context)
}
_ => {
Expand Down
32 changes: 21 additions & 11 deletions src/arrow_parquet/pg_to_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
type_compat::{
fallback_to_text::{reset_fallback_to_text_context, FallbackToText},
geometry::{is_postgis_geometry_type, Geometry},
map::{is_map_type, Map},
map::{is_map_type, reset_map_type_context, Map},
pg_arrow_type_conversions::{
extract_precision_and_scale_from_numeric_typmod, should_write_numeric_as_text,
},
Expand Down Expand Up @@ -60,10 +60,10 @@ pub(crate) struct PgToArrowAttributeContext {
attnum: i16,
typoid: Oid,
typmod: i32,
map_typoid: Option<Oid>,
is_array: bool,
is_composite: bool,
is_geometry: bool,
is_map: bool,
attribute_contexts: Option<Vec<PgToArrowAttributeContext>>,
scale: Option<u32>,
precision: Option<u32>,
Expand All @@ -80,21 +80,24 @@ impl PgToArrowAttributeContext {
let is_array = is_array_type(typoid);
let is_composite;
let is_geometry;
let is_map;
let attribute_typoid;
let attribute_field;
let map_typoid;

if is_array {
let element_typoid = array_element_typoid(typoid);

is_composite = is_composite_type(element_typoid);
is_geometry = is_postgis_geometry_type(element_typoid);
is_map = is_map_type(element_typoid);

if is_map {
if is_map_type(element_typoid) {
map_typoid = Some(element_typoid);

let entries_typoid = domain_array_base_elem_typoid(element_typoid);
attribute_typoid = entries_typoid;
} else {
map_typoid = None;

attribute_typoid = element_typoid;
}

Expand All @@ -105,19 +108,22 @@ impl PgToArrowAttributeContext {
} else {
is_composite = is_composite_type(typoid);
is_geometry = is_postgis_geometry_type(typoid);
is_map = is_map_type(typoid);

if is_map {
if is_map_type(typoid) {
map_typoid = Some(typoid);

let entries_typoid = domain_array_base_elem_typoid(typoid);
attribute_typoid = entries_typoid;
} else {
map_typoid = None;

attribute_typoid = typoid;
}

attribute_field = field.clone();
}

let attribute_tupledesc = if is_composite || is_map {
let attribute_tupledesc = if is_composite || map_typoid.is_some() {
Some(tuple_desc(attribute_typoid, typmod))
} else {
None
Expand Down Expand Up @@ -158,10 +164,10 @@ impl PgToArrowAttributeContext {
attnum,
typoid: attribute_typoid,
typmod,
map_typoid,
is_array,
is_composite,
is_geometry,
is_map,
attribute_contexts,
scale,
precision,
Expand Down Expand Up @@ -333,7 +339,9 @@ fn to_arrow_primitive_array(
}

attribute_vals.to_arrow_array(attribute_context)
} else if attribute_context.is_map {
} else if let Some(map_typoid) = attribute_context.map_typoid {
reset_map_type_context(map_typoid);

to_arrow_primitive_array!(Map, tuples, attribute_context)
} else if attribute_context.is_geometry {
to_arrow_primitive_array!(Geometry, tuples, attribute_context)
Expand Down Expand Up @@ -437,7 +445,9 @@ fn to_arrow_list_array(
}

attribute_vals.to_arrow_array(attribute_context)
} else if attribute_context.is_map {
} else if let Some(map_typoid) = attribute_context.map_typoid {
reset_map_type_context(map_typoid);

to_arrow_list_array!(pgrx::Array<Map>, tuples, attribute_context)
} else if attribute_context.is_geometry {
to_arrow_list_array!(pgrx::Array<Geometry>, tuples, attribute_context)
Expand Down
72 changes: 72 additions & 0 deletions src/pgrx_tests/copy_type_roundtrip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,78 @@ mod tests {
}
}

#[pg_test]
fn test_table_with_multiple_maps() {
// Skip the test if crunchy_map extension is not available
if !extension_exists("crunchy_map") {
return;
}

Spi::run("DROP EXTENSION IF EXISTS crunchy_map; CREATE EXTENSION crunchy_map;").unwrap();

Spi::run("SELECT crunchy_map.create('int','text');").unwrap();
Spi::run("SELECT crunchy_map.create('varchar','int');").unwrap();

let create_expected_table = "CREATE TABLE test_expected (a crunchy_map.key_int_val_text, b crunchy_map.key_varchar_val_int);";
Spi::run(create_expected_table).unwrap();

let insert = "INSERT INTO test_expected (a, b) VALUES ('{\"(1,)\",\"(2,myself)\",\"(3,ddd)\"}'::crunchy_map.key_int_val_text, '{\"(a,1)\",\"(b,2)\",\"(c,3)\"}'::crunchy_map.key_varchar_val_int);";
Spi::run(insert).unwrap();

let copy_to = format!("COPY test_expected TO '{LOCAL_TEST_FILE_PATH}'");
Spi::run(&copy_to).unwrap();

let create_result_table = "CREATE TABLE test_result (a crunchy_map.key_int_val_text, b crunchy_map.key_varchar_val_int);";
Spi::run(create_result_table).unwrap();

let copy_from = format!("COPY test_result FROM '{LOCAL_TEST_FILE_PATH}'");
Spi::run(&copy_from).unwrap();

let expected_a = Spi::connect(|client| {
let query = "SELECT (crunchy_map.entries(a)).* from test_expected;";
let tup_table = client.select(query, None, None).unwrap();

let mut results = Vec::new();

for row in tup_table {
let key = row["key"].value::<i32>().unwrap().unwrap();
let val = row["value"].value::<String>().unwrap();
results.push((key, val));
}

results
});

assert_eq!(
expected_a,
vec![
(1, None),
(2, Some("myself".into())),
(3, Some("ddd".into()))
]
);

let expected_b = Spi::connect(|client| {
let query = "SELECT (crunchy_map.entries(b)).* from test_expected;";
let tup_table = client.select(query, None, None).unwrap();

let mut results = Vec::new();

for row in tup_table {
let key = row["key"].value::<String>().unwrap().unwrap();
let val = row["value"].value::<i32>().unwrap().unwrap();
results.push((key, val));
}

results
});

assert_eq!(
expected_b,
vec![("a".into(), 1), ("b".into(), 2), ("c".into(), 3)]
);
}

#[pg_test]
#[should_panic(expected = "MapArray entries cannot contain nulls")]
fn test_map_null_entries() {
Expand Down
18 changes: 8 additions & 10 deletions src/type_compat/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ pub(crate) fn reset_map_context() {
};
}

pub(crate) fn reset_map_type_context(map_type_oid: Oid) {
get_map_context()
.map_type_context
.set_current_map_type_oid(map_type_oid);
}

pub(crate) fn is_map_type(typoid: Oid) -> bool {
let map_context = get_map_context();

Expand All @@ -58,7 +64,7 @@ pub(crate) fn is_map_type(typoid: Oid) -> bool {
return false;
}

let found_typoid = unsafe {
let map_typoid = unsafe {
GetSysCacheOid(
TYPEOID as _,
Anum_pg_type_oid as _,
Expand All @@ -69,15 +75,7 @@ pub(crate) fn is_map_type(typoid: Oid) -> bool {
)
};

let is_map = found_typoid != InvalidOid;

if is_map {
map_context
.map_type_context
.set_current_map_type_oid(typoid);
}

is_map
map_typoid != InvalidOid
}

#[derive(Debug, PartialEq, Clone)]
Expand Down

0 comments on commit e37d709

Please sign in to comment.