diff --git a/compiler/extension/fory_options.proto b/compiler/extension/fory_options.proto index 62ae968020..9d1cb8b7f6 100644 --- a/compiler/extension/fory_options.proto +++ b/compiler/extension/fory_options.proto @@ -179,10 +179,11 @@ message ForyFieldOptions { optional bool weak_ref = 4; // Generate thread-safe Rust pointer carriers for ref fields. - // When true, Rust codegen uses Arc/ArcWeak instead of Rc/RcWeak. + // Rust codegen uses Arc/ArcWeak by default. Set this to false to generate + // Rc/RcWeak for ref fields that must stay single-threaded. // This does not change the wire format and does not make the referenced // value itself thread-safe. - // Default: false + // Default: true optional bool thread_safe_pointer = 5; } diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py index 652a9cd6eb..a72642e66f 100644 --- a/compiler/fory_compiler/generators/rust.py +++ b/compiler/fory_compiler/generators/rust.py @@ -34,6 +34,7 @@ ArrayType, MapType, Schema, + thread_safe_pointer_enabled, ) from fory_compiler.ir.types import PrimitiveKind @@ -43,6 +44,7 @@ class RustGenerator(BaseGenerator): language_name = "rust" file_extension = ".rs" + RUST_ANY_TYPE = "::std::sync::Arc" # Mapping from FDL primitive types to Rust types PRIMITIVE_MAP = { @@ -62,7 +64,7 @@ class RustGenerator(BaseGenerator): PrimitiveKind.STRING: "::std::string::String", PrimitiveKind.BYTES: "::std::vec::Vec", PrimitiveKind.DECIMAL: "::fory::Decimal", - PrimitiveKind.ANY: "::std::boxed::Box", + PrimitiveKind.ANY: RUST_ANY_TYPE, } FORY_TEMPORAL_MAP = { @@ -1012,12 +1014,12 @@ def generate_type( element_optional: bool = False, element_ref: bool = False, parent_stack: Optional[List[Message]] = None, - pointer_type: str = "::std::rc::Rc", + pointer_type: str = "::std::sync::Arc", ) -> str: """Generate Rust type string.""" if isinstance(field_type, PrimitiveType): if field_type.kind == PrimitiveKind.ANY: - return "::std::boxed::Box" + return self.RUST_ANY_TYPE base_type = self.primitive_type_name(field_type.kind) if nullable: return f"::std::option::Option<{base_type}>" @@ -1152,7 +1154,7 @@ def get_field_pointer_type(self, field: Field) -> str: def get_pointer_type(self, ref_options: dict, weak_ref: bool = False) -> str: """Determine pointer type for ref tracking based on field options.""" - if ref_options.get("thread_safe_pointer") is True: + if thread_safe_pointer_enabled(ref_options): return "::fory::ArcWeak" if weak_ref else "::std::sync::Arc" return "::fory::RcWeak" if weak_ref else "::std::rc::Rc" diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py index f2718e84fa..24360fa019 100644 --- a/compiler/fory_compiler/ir/ast.py +++ b/compiler/fory_compiler/ir/ast.py @@ -22,6 +22,13 @@ from fory_compiler.ir.types import PrimitiveKind +THREAD_SAFE_POINTER_DEFAULT = True + + +def thread_safe_pointer_enabled(ref_options: dict) -> bool: + """Return the effective Rust pointer-carrier default for ref options.""" + return ref_options.get("thread_safe_pointer", THREAD_SAFE_POINTER_DEFAULT) is True + @dataclass(frozen=True) class SourceLocation: diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 62f6350f4f..f15eb39712 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -216,8 +216,8 @@ def test_rust_nested_container_ref_uses_correct_pointer_type(): } message Request { - list> groups = 1; - map> nodes = 2; + list> groups = 1; + map> nodes = 2; } """ ) @@ -497,7 +497,7 @@ def test_generated_code_map_types_equivalent(): assert "SharedWeak" in cpp_output -def test_rust_generated_ref_pointer_default_and_thread_safe_option(): +def test_rust_generated_ref_pointer_default_and_opt_out(): schema = parse_fdl( dedent( """ @@ -509,25 +509,24 @@ def test_rust_generated_ref_pointer_default_and_thread_safe_option(): message Holder { ref Node default_ref = 1; - ref(thread_safe=true) Node thread_safe_ref = 2; + ref(thread_safe=false) Node rc_ref = 2; ref(weak=true) Node default_weak_ref = 3; - ref(weak=true, thread_safe=true) Node thread_safe_weak_ref = 4; + ref(weak=true, thread_safe=false) Node rc_weak_ref = 4; list default_ref_list = 5; - list thread_safe_ref_list = 6; + list rc_ref_list = 6; } """ ) ) rust_output = render_files(generate_files(schema, RustGenerator)) - assert "pub default_ref: ::std::rc::Rc," in rust_output - assert "pub thread_safe_ref: ::std::sync::Arc," in rust_output - assert "pub default_weak_ref: ::fory::RcWeak," in rust_output - assert "pub thread_safe_weak_ref: ::fory::ArcWeak," in rust_output - assert "pub default_ref_list: ::std::vec::Vec<::std::rc::Rc>," in rust_output + assert "pub default_ref: ::std::sync::Arc," in rust_output + assert "pub rc_ref: ::std::rc::Rc," in rust_output + assert "pub default_weak_ref: ::fory::ArcWeak," in rust_output + assert "pub rc_weak_ref: ::fory::RcWeak," in rust_output assert ( - "pub thread_safe_ref_list: ::std::vec::Vec<::std::sync::Arc>," - in rust_output + "pub default_ref_list: ::std::vec::Vec<::std::sync::Arc>," in rust_output ) + assert "pub rc_ref_list: ::std::vec::Vec<::std::rc::Rc>," in rust_output def test_generated_code_nested_messages_equivalent(): @@ -779,7 +778,7 @@ def test_generated_code_tree_ref_options_equivalent(): assert_all_languages_equal(schemas) rust_output = render_files(generate_files(schemas["fdl"], RustGenerator)) - assert "RcWeak" in rust_output + assert "ArcWeak" in rust_output assert "#[derive(::fory::ForyStruct, Clone, PartialEq, Eq, Default)]" in rust_output cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator)) @@ -1153,8 +1152,11 @@ def test_rust_generated_code_uses_absolute_paths(): "pub labels: ::std::collections::HashMap<::std::string::String, ::std::string::String>," in rust_output ) - assert "pub payload: ::std::boxed::Box," in rust_output - assert "pub parent: ::fory::RcWeak," in rust_output + assert ( + "pub payload: ::std::sync::Arc," + in rust_output + ) + assert "pub parent: ::fory::ArcWeak," in rust_output assert "pub fn register_types(fory: &mut ::fory::Fory)" in rust_output assert "static FORY: ::std::sync::OnceLock<::fory::Fory>" in rust_output diff --git a/compiler/fory_compiler/tests/test_weak_ref.py b/compiler/fory_compiler/tests/test_weak_ref.py index 35fb1f64c5..a90cb8382b 100644 --- a/compiler/fory_compiler/tests/test_weak_ref.py +++ b/compiler/fory_compiler/tests/test_weak_ref.py @@ -122,7 +122,7 @@ def test_weak_ref_requires_repeated_ref(): ) -def test_list_and_map_ref_options_with_thread_safe(): +def test_list_and_map_ref_options_preserve_explicit_opt_out(): source = """ message Foo { int32 id = 1; @@ -134,8 +134,8 @@ def test_list_and_map_ref_options_with_thread_safe(): message Holder { list foos = 1; - list bars = 2; - map bar_map = 3; + list bars = 2; + map bar_map = 3; } """ schema = parse_schema(source) @@ -152,8 +152,8 @@ def test_list_and_map_ref_options_with_thread_safe(): assert bars.element_ref is True assert bars.element_ref_options.get("weak_ref") is True - assert bars.element_ref_options.get("thread_safe_pointer") is True + assert bars.element_ref_options.get("thread_safe_pointer") is False assert bar_map.field_type.value_ref is True assert bar_map.field_type.value_ref_options.get("weak_ref") is True - assert bar_map.field_type.value_ref_options.get("thread_safe_pointer") is True + assert bar_map.field_type.value_ref_options.get("thread_safe_pointer") is False diff --git a/docs/compiler/flatbuffers-idl.md b/docs/compiler/flatbuffers-idl.md index 71a249f03e..f7aa298c24 100644 --- a/docs/compiler/flatbuffers-idl.md +++ b/docs/compiler/flatbuffers-idl.md @@ -156,20 +156,20 @@ FlatBuffers metadata attributes use `key:value`. For Fory-specific options, use ### Supported Field Attributes -| FlatBuffers Attribute | Effect in Fory | -| ------------------------------- | -------------------------------------------------------------------------------- | -| `fory_ref:true` | Enable reference tracking for the field | -| `fory_nullable:true` | Mark field optional/nullable | -| `fory_weak_ref:true` | Enable weak reference semantics and implies `ref` | -| `fory_thread_safe_pointer:true` | For ref fields, select Rust `Arc`/`ArcWeak` instead of the default `Rc`/`RcWeak` | +| FlatBuffers Attribute | Effect in Fory | +| -------------------------------- | -------------------------------------------------------------------------------- | +| `fory_ref:true` | Enable reference tracking for the field | +| `fory_nullable:true` | Mark field optional/nullable | +| `fory_weak_ref:true` | Enable weak reference semantics and implies `ref` | +| `fory_thread_safe_pointer:false` | For ref fields, select Rust `Rc`/`RcWeak` instead of the default `Arc`/`ArcWeak` | Semantics: - `fory_weak_ref:true` implies `ref`. -- `fory_thread_safe_pointer` defaults to `false`, only takes effect when the field +- `fory_thread_safe_pointer` defaults to `true`, only takes effect when the field is ref-tracked, and does not change the wire format. -- In Rust codegen, `fory_weak_ref:true` uses `RcWeak` by default and switches to - `ArcWeak` only when `fory_thread_safe_pointer:true` is also set. +- In Rust codegen, `fory_weak_ref:true` uses `ArcWeak` by default and switches to + `RcWeak` only when `fory_thread_safe_pointer:false` is set. - For list fields, `fory_ref:true` applies to list elements. Example: @@ -178,7 +178,7 @@ Example: table Node { parent: Node (fory_weak_ref: true); children: [Node] (fory_ref: true); - cached: Node (fory_ref: true, fory_thread_safe_pointer: true); + local: Node (fory_ref: true, fory_thread_safe_pointer: false); } ``` diff --git a/docs/compiler/protobuf-idl.md b/docs/compiler/protobuf-idl.md index 3a13fb8ff4..59f4c63a77 100644 --- a/docs/compiler/protobuf-idl.md +++ b/docs/compiler/protobuf-idl.md @@ -232,14 +232,14 @@ message TreeNode { ### Field-Level Options -| Option | Type | Description | -| ---------------------------- | ------ | --------------------------------------------------------------------------- | -| `(fory).ref` | bool | Enable reference tracking for this field | -| `(fory).nullable` | bool | Treat field as nullable (`optional`) | -| `(fory).weak_ref` | bool | Generate weak pointer semantics (C++/Rust codegen) | -| `(fory).thread_safe_pointer` | bool | Use Rust `Arc`/`ArcWeak` for ref fields; default `false` uses `Rc`/`RcWeak` | -| `(fory).deprecated` | bool | Mark field as deprecated | -| `(fory).type` | string | Primitive override for tagged 64-bit integer encoding | +| Option | Type | Description | +| ---------------------------- | ------ | ---------------------------------------------------------------------------------------------------- | +| `(fory).ref` | bool | Enable reference tracking for this field | +| `(fory).nullable` | bool | Treat field as nullable (`optional`) | +| `(fory).weak_ref` | bool | Generate weak pointer semantics (C++/Rust codegen) | +| `(fory).thread_safe_pointer` | bool | Rust ref carrier selection; default `true` uses `Arc`/`ArcWeak`, explicit `false` uses `Rc`/`RcWeak` | +| `(fory).deprecated` | bool | Mark field as deprecated | +| `(fory).type` | string | Primitive override for tagged 64-bit integer encoding | Reference option behavior: @@ -247,19 +247,20 @@ Reference option behavior: - For `repeated` fields, `(fory).ref = true` applies to list elements. - For `map` fields, `(fory).ref = true` applies to map values. - `weak_ref` and `thread_safe_pointer` are codegen hints for C++/Rust. -- `thread_safe_pointer` defaults to `false`; it changes only the generated Rust +- `thread_safe_pointer` defaults to `true`; it changes only the generated Rust pointer carrier and does not change the wire format. -- In Rust codegen, `(fory).weak_ref = true` uses `RcWeak` by default and switches - to `ArcWeak` only when `(fory).thread_safe_pointer = true`. +- In Rust codegen, `(fory).weak_ref = true` uses `ArcWeak` by default and + switches to `RcWeak` only when `(fory).thread_safe_pointer = false`. ### Option Examples by Shape ```protobuf message Graph { - Node root = 1 [(fory).ref = true, (fory).thread_safe_pointer = true]; + Node root = 1 [(fory).ref = true]; repeated Node nodes = 2 [(fory).ref = true]; map cache = 3 [(fory).ref = true]; Node parent = 4 [(fory).weak_ref = true]; + Node local = 5 [(fory).ref = true, (fory).thread_safe_pointer = false]; } ``` diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md index 5b99fc892a..89705883b0 100644 --- a/docs/compiler/schema-idl.md +++ b/docs/compiler/schema-idl.md @@ -1052,7 +1052,7 @@ message Node { | Java | `Node parent` | `Node parent` with `@Ref` | | Python | `parent: Node` | `parent: Node = pyfory.field(ref=True)` | | Go | `Parent Node` | `Parent *Node` with `fory:"ref"` | -| Rust | `parent: Node` | `parent: Rc` | +| Rust | `parent: Node` | `parent: Arc` | | C++ | `Node parent` | `std::shared_ptr parent` | | C# | `Node parent` | `Node? parent` with runtime ref tracking | | JavaScript/TypeScript | `parent: Node` | `parent: Node` (no ref distinction) | @@ -1061,21 +1061,21 @@ message Node { | Scala | `parent: Node` | `@Ref parent: Node` | | Kotlin | `parent: Node` | `@Ref parent: Node?` | -Rust uses `Rc` and `RcWeak` by default for ref-tracked fields. Use -`ref(thread_safe=true)` when the generated Rust type must use `Arc` or -`ArcWeak` for cross-thread shared ownership. This setting is a Rust codegen -carrier choice; it does not change the wire format or make the referenced value -itself thread-safe. For protobuf option syntax, see +Rust uses `Arc` and `ArcWeak` by default for ref-tracked fields. Use +`ref(thread_safe=false)` when a generated Rust type must use single-threaded +`Rc` or `RcWeak` carriers. This setting is a Rust codegen carrier choice; it +does not change the wire format or make the referenced value itself +thread-safe. For protobuf option syntax, see [Protocol Buffers IDL Support](protobuf-idl.md#field-level-options). Rust pointer carrier mapping: -| Fory IDL | Rust type | -| ---------------------------------------------- | --------------- | -| `ref Node parent` | `Rc` | -| `ref(thread_safe=true) Node parent` | `Arc` | -| `ref(weak=true) Node parent` | `RcWeak` | -| `ref(weak=true, thread_safe=true) Node parent` | `ArcWeak` | +| Fory IDL | Rust type | +| ----------------------------------------------- | --------------- | +| `ref Node parent` | `Arc` | +| `ref(thread_safe=false) Node parent` | `Rc` | +| `ref(weak=true) Node parent` | `ArcWeak` | +| `ref(weak=true, thread_safe=false) Node parent` | `RcWeak` | #### `list` @@ -1127,10 +1127,11 @@ accepted as an alias for `list`. | ----------------------- | ---------------------------------- | --------------------- | ----------------------- | --------------------- | ----------------------------------------- | ------------------------------------------------------------- | ---------------------- | | `optional list` | `@Nullable List` | `Optional[List[str]]` | `[]string` + `nullable` | `Option>` | `std::optional>` | `List?` | `Option[List[String]]` | | `list` | `List` (nullable elements) | `List[Optional[str]]` | `[]*string` | `Vec>` | `std::vector>` | `List` | `List[Option[String]]` | -| `list` | `List<@Ref User>` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | +| `list` | `List<@Ref User>` | `List[User]` | `[]*User` + `ref=false` | `Vec>` | `std::vector>` | `List` + `@ListField(element: DeclaredType(ref: true))` | `List[User @Ref]` | -Use `ref(thread_safe=true)` in Fory IDL (or `[(fory).thread_safe_pointer = true]` in protobuf) -to generate `Arc` instead of `Rc` in Rust. +Use `ref(thread_safe=false)` in Fory IDL (or +`[(fory).thread_safe_pointer = false]` in protobuf) to generate `Rc` instead of +`Arc` in Rust. ## Field Numbers @@ -1339,15 +1340,15 @@ Underscore spellings for integer encoding are not FDL type names. #### Any -| Language | Type | Notes | -| --------------------- | -------------- | -------------------- | -| Java | `Object` | Runtime type written | -| Python | `Any` | Runtime type written | -| Go | `any` | Runtime type written | -| Rust | `Box` | Runtime type written | -| C++ | `std::any` | Runtime type written | -| JavaScript/TypeScript | `any` | Runtime type written | -| Dart | `Object?` | Runtime type written | +| Language | Type | Notes | +| --------------------- | ---------------------------- | -------------------- | +| Java | `Object` | Runtime type written | +| Python | `Any` | Runtime type written | +| Go | `any` | Runtime type written | +| Rust | `Arc` | Runtime type written | +| C++ | `std::any` | Runtime type written | +| JavaScript/TypeScript | `any` | Runtime type written | +| Dart | `Object?` | Runtime type written | **Example:** @@ -1369,15 +1370,15 @@ message Envelope [id=122] { **Generated Code (`Envelope.payload`):** -| Language | Generated Field Type | -| --------------------- | ----------------------- | -| Java | `Object payload` | -| Python | `payload: Any` | -| Go | `Payload any` | -| Rust | `payload: Box` | -| C++ | `std::any payload` | -| JavaScript/TypeScript | `payload: any` | -| Dart | `Object? payload` | +| Language | Generated Field Type | +| --------------------- | ------------------------------------- | +| Java | `Object payload` | +| Python | `payload: Any` | +| Go | `Payload any` | +| Rust | `payload: Arc` | +| C++ | `std::any payload` | +| JavaScript/TypeScript | `payload: any` | +| Dart | `Object? payload` | **Notes:** diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 83fc6e538c..d140f0357c 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -80,7 +80,7 @@ let fory = Fory::builder().xlang(true).max_dyn_depth(10).build(); // Allow up to **Protected types:** -- `Box`, `Rc`, `Arc` +- `Box`, `Rc`, `Arc` - `Box`, `Rc`, `Arc` (trait objects) - `RcWeak`, `ArcWeak` - Collection types (Vec, HashMap, HashSet) diff --git a/docs/guide/rust/native-serialization.md b/docs/guide/rust/native-serialization.md index cba76f7010..e2bf6e74f4 100644 --- a/docs/guide/rust/native-serialization.md +++ b/docs/guide/rust/native-serialization.md @@ -100,7 +100,7 @@ Native serialization owns the Rust-specific object surface: - `Box`, `Rc`, `Arc`, `RcWeak`, and `ArcWeak`. - `RefCell` and `Mutex`. - Trait objects such as `Box`, `Rc`, and `Arc`. -- Runtime type dispatch with `Rc` and `Arc`. +- Runtime type dispatch with `Rc` and `Arc`. - Date and time carriers, including optional `chrono` support. Use [Basic Serialization](basic-serialization.md), [References](references.md), and diff --git a/docs/guide/rust/polymorphism.md b/docs/guide/rust/polymorphism.md index 1f21d9eacd..fa87735edd 100644 --- a/docs/guide/rust/polymorphism.md +++ b/docs/guide/rust/polymorphism.md @@ -85,7 +85,8 @@ assert_eq!(decoded.star_animal.speak(), "Woof!"); ## Serializing dyn Any Trait Objects -Apache Fory™ supports serializing `Rc` and `Arc` for runtime type dispatch: +Apache Fory™ supports serializing `Rc` and +`Arc` for runtime type dispatch: **Key points:** @@ -98,14 +99,11 @@ Apache Fory™ supports serializing `Rc` and `Arc` for runtime use std::rc::Rc; use std::any::Any; -let dog_rc: Rc = Rc::new(Dog { +let dog_any: Rc = Rc::new(Dog { name: "Rex".to_string(), breed: "Golden".to_string() }); -// Convert to Rc for serialization -let dog_any: Rc = dog_rc.clone(); - // Serialize the Any wrapper let bytes = fory.serialize(&dog_any)?; let decoded: Rc = fory.deserialize(&bytes)?; @@ -115,22 +113,19 @@ let unwrapped = decoded.downcast_ref::().unwrap(); assert_eq!(unwrapped.name, "Rex"); ``` -For thread-safe scenarios, use `Arc`: +For thread-safe scenarios, use `Arc`: ```rust use std::sync::Arc; use std::any::Any; -let dog_arc: Arc = Arc::new(Dog { +let dog_any: Arc = Arc::new(Dog { name: "Buddy".to_string(), breed: "Labrador".to_string() }); -// Convert to Arc -let dog_any: Arc = dog_arc.clone(); - let bytes = fory.serialize(&dog_any)?; -let decoded: Arc = fory.deserialize(&bytes)?; +let decoded: Arc = fory.deserialize(&bytes)?; // Downcast to concrete type let unwrapped = decoded.downcast_ref::().unwrap(); @@ -185,7 +180,7 @@ assert_eq!(decoded.animals_arc[0].speak(), "Woof!"); Due to Rust's orphan rule, `Rc` and `Arc` cannot implement `Serializer` directly. For standalone serialization (not inside struct fields), the `register_trait_type!` macro generates wrapper types. -**Note:** If you don't want to use wrapper types, you can serialize as `Rc` or `Arc` instead (see the dyn Any section above). +**Note:** If you don't want to use wrapper types, you can serialize as `Rc` or `Arc` instead (see the dyn Any section above). The `register_trait_type!` macro generates `AnimalRc` and `AnimalArc` wrapper types: diff --git a/docs/guide/rust/schema-evolution.md b/docs/guide/rust/schema-evolution.md index 5886870759..2766641ca9 100644 --- a/docs/guide/rust/schema-evolution.md +++ b/docs/guide/rust/schema-evolution.md @@ -134,7 +134,9 @@ For typed ADT unions whose schema cases are unit or single-payload variants, forward-compatibility carrier. It cannot be the default variant, and the union must include at least one real schema case. The marker only selects the carrier and does not add an entry to the schema case table; schema cases use -non-negative IDs. +non-negative IDs. `UnknownCase` stores its payload as +`Arc`, so locally registered future payload types must be +thread-safe to be preserved as unknown cases. ### Enum Schema Evolution diff --git a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs index 5d28da0f94..3e68444e14 100644 --- a/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs +++ b/integration_tests/idl_tests/rust/tests/idl_roundtrip.rs @@ -16,11 +16,11 @@ // under the License. use std::collections::HashMap; -use std::rc::Rc; +use std::sync::Arc; use std::{env, fs}; use chrono::NaiveDate; -use fory::{BFloat16, Float16, Fory, RcWeak}; +use fory::{ArcWeak, BFloat16, Float16, Fory}; use idl_tests::generated::addressbook::{ self, person::{PhoneNumber, PhoneType}, @@ -303,41 +303,41 @@ fn build_optional_holder() -> OptionalHolder { fn build_any_holder() -> AnyHolder { AnyHolder { - bool_value: Box::new(true), - string_value: Box::new("hello".to_string()), - date_value: Box::new(NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()), - timestamp_value: Box::new( + bool_value: Arc::new(true), + string_value: Arc::new("hello".to_string()), + date_value: Arc::new(NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()), + timestamp_value: Arc::new( NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() .and_hms_opt(3, 4, 5) .expect("timestamp"), ), - message_value: Box::new(AnyInner { + message_value: Arc::new(AnyInner { name: "inner".to_string(), }), - union_value: Box::new(AnyUnion::Text("union".to_string())), - list_value: Box::new("list-placeholder".to_string()), - map_value: Box::new("map-placeholder".to_string()), + union_value: Arc::new(AnyUnion::Text("union".to_string())), + list_value: Arc::new("list-placeholder".to_string()), + map_value: Arc::new("map-placeholder".to_string()), } } fn build_any_holder_with_collections() -> AnyHolder { AnyHolder { - bool_value: Box::new(true), - string_value: Box::new("hello".to_string()), - date_value: Box::new(NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()), - timestamp_value: Box::new( + bool_value: Arc::new(true), + string_value: Arc::new("hello".to_string()), + date_value: Arc::new(NaiveDate::from_ymd_opt(2024, 1, 2).unwrap()), + timestamp_value: Arc::new( NaiveDate::from_ymd_opt(2024, 1, 2) .unwrap() .and_hms_opt(3, 4, 5) .expect("timestamp"), ), - message_value: Box::new(AnyInner { + message_value: Arc::new(AnyInner { name: "inner".to_string(), }), - union_value: Box::new(AnyUnion::Text("union".to_string())), - list_value: Box::new(vec!["alpha".to_string(), "beta".to_string()]), - map_value: Box::new(HashMap::from([ + union_value: Arc::new(AnyUnion::Text("union".to_string())), + list_value: Arc::new(vec!["alpha".to_string(), "beta".to_string()]), + map_value: Arc::new(HashMap::from([ ("k1".to_string(), "v1".to_string()), ("k2".to_string(), "v2".to_string()), ])), @@ -638,31 +638,31 @@ fn assert_any_holder(holder: &AnyHolder) { } fn build_tree() -> tree::TreeNode { - let mut child_a = Rc::new(tree::TreeNode { + let mut child_a = Arc::new(tree::TreeNode { id: "child-a".to_string(), name: "child-a".to_string(), children: vec![], parent: None, }); - let mut child_b = Rc::new(tree::TreeNode { + let mut child_b = Arc::new(tree::TreeNode { id: "child-b".to_string(), name: "child-b".to_string(), children: vec![], parent: None, }); - let child_a_weak = RcWeak::from(&child_a); - let child_b_weak = RcWeak::from(&child_b); - Rc::get_mut(&mut child_a).expect("child a unique").parent = Some(child_b_weak); - Rc::get_mut(&mut child_b).expect("child b unique").parent = Some(child_a_weak); + let child_a_weak = ArcWeak::from(&child_a); + let child_b_weak = ArcWeak::from(&child_b); + Arc::get_mut(&mut child_a).expect("child a unique").parent = Some(child_b_weak); + Arc::get_mut(&mut child_b).expect("child b unique").parent = Some(child_a_weak); tree::TreeNode { id: "root".to_string(), name: "root".to_string(), children: vec![ - Rc::clone(&child_a), - Rc::clone(&child_a), - Rc::clone(&child_b), + Arc::clone(&child_a), + Arc::clone(&child_a), + Arc::clone(&child_b), ], parent: None, } @@ -670,8 +670,8 @@ fn build_tree() -> tree::TreeNode { fn assert_tree(root: &tree::TreeNode) { assert_eq!(root.children.len(), 3); - assert!(Rc::ptr_eq(&root.children[0], &root.children[1])); - assert!(!Rc::ptr_eq(&root.children[0], &root.children[2])); + assert!(Arc::ptr_eq(&root.children[0], &root.children[1])); + assert!(!Arc::ptr_eq(&root.children[0], &root.children[2])); let parent_a = root.children[0] .parent .as_ref() @@ -684,36 +684,36 @@ fn assert_tree(root: &tree::TreeNode) { .expect("child b parent") .upgrade() .expect("upgrade child b parent"); - assert!(Rc::ptr_eq(&parent_a, &root.children[2])); - assert!(Rc::ptr_eq(&parent_b, &root.children[0])); + assert!(Arc::ptr_eq(&parent_a, &root.children[2])); + assert!(Arc::ptr_eq(&parent_b, &root.children[0])); } fn build_graph() -> graph::Graph { - let mut node_a = Rc::new(graph::Node { + let mut node_a = Arc::new(graph::Node { id: "node-a".to_string(), out_edges: vec![], in_edges: vec![], }); - let mut node_b = Rc::new(graph::Node { + let mut node_b = Arc::new(graph::Node { id: "node-b".to_string(), out_edges: vec![], in_edges: vec![], }); - let edge = Rc::new(graph::Edge { + let edge = Arc::new(graph::Edge { id: "edge-1".to_string(), weight: 1.5_f32, - from: Some(RcWeak::from(&node_a)), - to: Some(RcWeak::from(&node_b)), + from: Some(ArcWeak::from(&node_a)), + to: Some(ArcWeak::from(&node_b)), }); - Rc::get_mut(&mut node_a).expect("node a unique").out_edges = vec![Rc::clone(&edge)]; - Rc::get_mut(&mut node_a).expect("node a unique").in_edges = vec![Rc::clone(&edge)]; - Rc::get_mut(&mut node_b).expect("node b unique").in_edges = vec![Rc::clone(&edge)]; + Arc::get_mut(&mut node_a).expect("node a unique").out_edges = vec![Arc::clone(&edge)]; + Arc::get_mut(&mut node_a).expect("node a unique").in_edges = vec![Arc::clone(&edge)]; + Arc::get_mut(&mut node_b).expect("node b unique").in_edges = vec![Arc::clone(&edge)]; graph::Graph { - nodes: vec![Rc::clone(&node_a), Rc::clone(&node_b)], - edges: vec![Rc::clone(&edge)], + nodes: vec![Arc::clone(&node_a), Arc::clone(&node_b)], + edges: vec![Arc::clone(&edge)], } } @@ -723,8 +723,8 @@ fn assert_graph(value: &graph::Graph) { let node_a = &value.nodes[0]; let node_b = &value.nodes[1]; let edge = &value.edges[0]; - assert!(Rc::ptr_eq(&node_a.out_edges[0], &node_a.in_edges[0])); - assert!(Rc::ptr_eq(&node_a.out_edges[0], edge)); + assert!(Arc::ptr_eq(&node_a.out_edges[0], &node_a.in_edges[0])); + assert!(Arc::ptr_eq(&node_a.out_edges[0], edge)); let from = edge .from .as_ref() @@ -737,8 +737,8 @@ fn assert_graph(value: &graph::Graph) { .expect("edge to") .upgrade() .expect("upgrade to"); - assert!(Rc::ptr_eq(&from, node_a)); - assert!(Rc::ptr_eq(&to, node_b)); + assert!(Arc::ptr_eq(&from, node_a)); + assert!(Arc::ptr_eq(&to, node_b)); } #[test] diff --git a/rust/README.md b/rust/README.md index 3b381f6909..dd9ea0708d 100644 --- a/rust/README.md +++ b/rust/README.md @@ -248,7 +248,7 @@ The examples in this section use native mode because Rust trait objects and `dyn - `Box` - Owned trait objects - `Rc` - Reference-counted trait objects - `Arc` - Thread-safe reference-counted trait objects -- `Box`/`Rc`/`Arc` - Any trait type objects +- `Box`/`Rc`/`Arc` - Any trait type objects - `Vec>`, `HashMap>` - Collections of trait objects **Basic Trait Object Serialization Example:** diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 7983cc0c81..bb9ae8e536 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -63,7 +63,10 @@ type ReadFn = type WriteDataFn = fn(&dyn Any, &mut WriteContext, has_generics: bool) -> Result<(), Error>; type ReadDataFn = fn(&mut ReadContext) -> Result, Error>; +type ReadDataSendSyncFn = fn(&mut ReadContext) -> Result, Error>; type ReadCompatibleFn = fn(&mut ReadContext, Rc) -> Result, Error>; +type ReadCompatibleSendSyncFn = + fn(&mut ReadContext, Rc) -> Result, Error>; type ToSerializerFn = fn(Box) -> Result, Error>; type BuildTypeInfosFn = fn(&TypeResolver) -> Result, Error>; const EMPTY_STRING: String = String::new(); @@ -87,44 +90,30 @@ pub struct Harness { read_fn: ReadFn, write_data_fn: WriteDataFn, read_data_fn: ReadDataFn, + read_data_send_sync_fn: ReadDataSendSyncFn, read_compatible_fn: Option, + read_compatible_send_sync_fn: Option, + threadsafe: bool, to_serializer: ToSerializerFn, build_type_infos: BuildTypeInfosFn, } impl Harness { - pub fn new( - write_fn: WriteFn, - read_fn: ReadFn, - write_data_fn: WriteDataFn, - read_data_fn: ReadDataFn, - read_compatible_fn: Option, - to_serializer: ToSerializerFn, - build_type_infos: BuildTypeInfosFn, - ) -> Harness { + pub fn stub() -> Harness { Harness { - write_fn, - read_fn, - write_data_fn, - read_data_fn, - read_compatible_fn, - to_serializer, - build_type_infos, + write_fn: stub_write_fn, + read_fn: stub_read_fn, + write_data_fn: stub_write_data_fn, + read_data_fn: stub_read_data_fn, + read_data_send_sync_fn: stub_read_data_send_sync_fn, + read_compatible_fn: None, + read_compatible_send_sync_fn: None, + threadsafe: false, + to_serializer: stub_to_serializer_fn, + build_type_infos: stub_build_type_infos, } } - pub fn stub() -> Harness { - Harness::new( - stub_write_fn, - stub_read_fn, - stub_write_data_fn, - stub_read_data_fn, - None, - stub_to_serializer_fn, - stub_build_type_infos, - ) - } - #[inline(always)] pub fn get_write_fn(&self) -> WriteFn { self.write_fn @@ -172,6 +161,29 @@ impl Harness { } (self.read_data_fn)(context) } + + /// Reads polymorphic data for `Arc` carriers. + /// This path never upgrades an ordinary `Box`; it delegates to + /// type-owned readers that construct the send-sync trait object directly. + #[inline(always)] + pub fn read_polymorphic_data_send_sync( + &self, + context: &mut ReadContext, + typeinfo: &Rc, + ) -> Result, Error> { + if !self.threadsafe { + return Err(Error::type_error(format!( + "{}::{} cannot be represented as Arc", + typeinfo.namespace.original, typeinfo.type_name.original + ))); + } + if context.is_compatible() { + if let Some(read_compatible_fn) = self.read_compatible_send_sync_fn { + return read_compatible_fn(context, typeinfo.clone()); + } + } + (self.read_data_send_sync_fn)(context) + } } #[derive(Clone, Debug)] @@ -314,15 +326,18 @@ impl TypeInfo { h.clone() } else { // Create a stub harness that returns errors when called - Harness::new( - stub_write_fn, - stub_read_fn, - stub_write_data_fn, - stub_read_data_fn, - None, - stub_to_serializer_fn, - stub_build_type_infos, - ) + Harness { + write_fn: stub_write_fn, + read_fn: stub_read_fn, + write_data_fn: stub_write_data_fn, + read_data_fn: stub_read_data_fn, + read_data_send_sync_fn: stub_read_data_send_sync_fn, + read_compatible_fn: None, + read_compatible_send_sync_fn: None, + threadsafe: false, + to_serializer: stub_to_serializer_fn, + build_type_infos: stub_build_type_infos, + } }; TypeInfo { @@ -369,6 +384,12 @@ fn stub_read_data_fn(_: &mut ReadContext) -> Result, Error> { )) } +fn stub_read_data_send_sync_fn(_: &mut ReadContext) -> Result, Error> { + Err(Error::type_error( + "Cannot deserialize unknown remote type as Arc - type not registered locally", + )) +} + fn stub_to_serializer_fn(_: Box) -> Result, Error> { Err(Error::type_error( "Cannot convert unknown remote type to serializer", @@ -945,6 +966,20 @@ impl TypeResolver { } } + fn read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> { + if T2::fory_is_threadsafe_type() { + T2::fory_read_data_send_sync(context) + } else if crate::serializer::is_known_threadsafe_static_type_id( + T2::fory_static_type_id(), + ) { + crate::serializer::read_known_threadsafe_data::(context) + } else { + Err(crate::serializer::unsupported_threadsafe_type::()) + } + } + fn to_serializer( boxed_any: Box, ) -> Result, Error> { @@ -967,15 +1002,26 @@ impl TypeResolver { Ok(Box::new(T2::fory_read_compatible(context, type_info)?)) } - let harness = Harness::new( - write::, - read::, - write_data::, - read_data::, - Some(read_compatible::), - to_serializer::, - build_type_infos::, - ); + fn read_compatible_send_sync( + context: &mut ReadContext, + type_info: Rc, + ) -> Result, Error> { + T2::fory_read_compatible_send_sync(context, type_info) + } + + let harness = Harness { + write_fn: write::, + read_fn: read::, + write_data_fn: write_data::, + read_data_fn: read_data::, + read_data_send_sync_fn: read_data_send_sync::, + read_compatible_fn: Some(read_compatible::), + read_compatible_send_sync_fn: Some(read_compatible_send_sync::), + threadsafe: T::fory_is_threadsafe_type() + || crate::serializer::is_known_threadsafe_static_type_id(T::fory_static_type_id()), + to_serializer: to_serializer::, + build_type_infos: build_type_infos::, + }; let type_info = TypeInfo::new( actual_type_id, user_type_id, @@ -1166,6 +1212,20 @@ impl TypeResolver { } } + fn read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> { + if T2::fory_is_threadsafe_type() { + T2::fory_read_data_send_sync(context) + } else if crate::serializer::is_known_threadsafe_static_type_id( + T2::fory_static_type_id(), + ) { + crate::serializer::read_known_threadsafe_data::(context) + } else { + Err(crate::serializer::unsupported_threadsafe_type::()) + } + } + fn to_serializer( boxed_any: Box, ) -> Result, Error> { @@ -1191,15 +1251,19 @@ impl TypeResolver { } // EXT types don't support fory_read_compatible - let harness = Harness::new( - write::, - read::, - write_data::, - read_data::, - None, - to_serializer::, - build_type_infos::, - ); + let harness = Harness { + write_fn: write::, + read_fn: read::, + write_data_fn: write_data::, + read_data_fn: read_data::, + read_data_send_sync_fn: read_data_send_sync::, + read_compatible_fn: None, + read_compatible_send_sync_fn: None, + threadsafe: T::fory_is_threadsafe_type() + || crate::serializer::is_known_threadsafe_static_type_id(T::fory_static_type_id()), + to_serializer: to_serializer::, + build_type_infos: build_type_infos::, + }; let user_type_id = if register_by_name { NO_USER_TYPE_ID diff --git a/rust/fory-core/src/serializer/any.rs b/rust/fory-core/src/serializer/any.rs index 9ac76de6c4..e4e9067be8 100644 --- a/rust/fory-core/src/serializer/any.rs +++ b/rust/fory-core/src/serializer/any.rs @@ -46,7 +46,7 @@ pub(crate) fn check_generic_container_type(type_info: &TypeInfo) -> Result<(), E if type_id == TypeId::LIST || type_id == TypeId::SET || type_id == TypeId::MAP { return Err(Error::type_error( "Cannot deserialize generic container types (Vec, HashSet, HashMap) polymorphically \ - via Box/Rc/Arc/Weak. The serialization protocol does not preserve the element type \ + via Box, Rc, or Arc. The serialization protocol does not preserve the element type \ information needed to distinguish between different generic instantiations \ (e.g., Vec vs Vec). Consider wrapping the container in a \ named struct type instead.", @@ -322,6 +322,13 @@ impl Serializer for Rc { true } + fn fory_is_threadsafe_type() -> bool + where + Self: Sized, + { + true + } + fn fory_is_polymorphic() -> bool { true } @@ -402,13 +409,13 @@ pub fn read_rc_any( } } -impl ForyDefault for Arc { +impl ForyDefault for Arc { fn fory_default() -> Self { Arc::new(()) } } -impl Serializer for Arc { +impl Serializer for Arc { fn fory_write( &self, context: &mut WriteContext, @@ -421,18 +428,19 @@ impl Serializer for Arc { .ref_writer .try_write_arc_ref(&mut context.writer, self) { - let concrete_type_id: std::any::TypeId = (**self).type_id(); + let value: &dyn Any = self.as_ref(); + let concrete_type_id: std::any::TypeId = value.type_id(); if write_type_info { let typeinfo = context.write_any_type_info(TypeId::UNKNOWN as u32, concrete_type_id)?; let serializer_fn = typeinfo.get_harness().get_write_data_fn(); - serializer_fn(&**self, context, has_generics)?; + serializer_fn(value, context, has_generics)?; } else { let serializer_fn = context .get_type_info(&concrete_type_id)? .get_harness() .get_write_data_fn(); - serializer_fn(&**self, context, has_generics)?; + serializer_fn(value, context, has_generics)?; } } Ok(()) @@ -470,24 +478,23 @@ impl Serializer for Arc { } fn fory_read_data(_: &mut ReadContext) -> Result { - Err(Error::not_allowed(format!( - "fory_read_data should not be called directly on polymorphic Rc trait object", - stringify!($trait_name) - ))) + Err(Error::not_allowed( + "fory_read_data should not be called directly on polymorphic Arc trait object" + )) } fn fory_get_type_id(_type_resolver: &TypeResolver) -> Result { Err(Error::type_error( - "Arc has no static type ID - use fory_type_id_dyn", + "Arc has no static type ID - use fory_type_id_dyn", )) } fn fory_type_id_dyn(&self, type_resolver: &TypeResolver) -> Result { - resolve_registered_type_id(type_resolver, (**self).type_id()) + resolve_registered_type_id(type_resolver, self.as_ref().type_id()) } fn fory_concrete_type_id(&self) -> std::any::TypeId { - (**self).type_id() + self.as_ref().type_id() } fn fory_is_polymorphic() -> bool { @@ -503,17 +510,31 @@ impl Serializer for Arc { } fn fory_write_type_info(_context: &mut WriteContext) -> Result<(), Error> { - // Arc is polymorphic - type info is written per element + // Arc is polymorphic - type info is written per element Ok(()) } fn fory_read_type_info(_context: &mut ReadContext) -> Result<(), Error> { - // Arc is polymorphic - type info is read per element + // Arc is polymorphic - type info is read per element Ok(()) } + fn fory_read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> + where + Self: Sized + ForyDefault, + { + Ok(crate::serializer::box_send_sync(read_arc_any( + context, + RefMode::None, + true, + None, + )?)) + } + fn as_any(&self) -> &dyn Any { - &**self + self.as_ref() } } @@ -522,21 +543,26 @@ pub fn read_arc_any( ref_mode: RefMode, read_type_info: bool, type_info: Option>, -) -> Result, Error> { +) -> Result, Error> { let ref_flag = if ref_mode != RefMode::None { context.ref_reader.read_ref_flag(&mut context.reader)? } else { RefFlag::NotNullValue }; match ref_flag { - RefFlag::Null => Err(Error::invalid_ref("Arc cannot be null")), + RefFlag::Null => Err(Error::invalid_ref( + "Arc cannot be null", + )), RefFlag::Ref => { let ref_id = context.ref_reader.read_ref_id(&mut context.reader)?; context .ref_reader - .get_arc_ref::(ref_id) + .get_arc_ref::(ref_id) .ok_or_else(|| { - Error::invalid_data(format!("Arc reference {} not found", ref_id)) + Error::invalid_data(format!( + "Arc reference {} not found", + ref_id + )) }) } RefFlag::NotNullValue => { @@ -544,32 +570,34 @@ pub fn read_arc_any( let typeinfo = if read_type_info { context.read_any_type_info()? } else { - type_info - .ok_or_else(|| Error::type_error("No type info found for read Arc"))? + type_info.ok_or_else(|| { + Error::type_error("No type info found for read Arc") + })? }; // Check for generic container types which cannot be deserialized polymorphically check_generic_container_type(&typeinfo)?; let boxed = typeinfo .get_harness() - .read_polymorphic_data(context, &typeinfo)?; + .read_polymorphic_data_send_sync(context, &typeinfo)?; context.dec_depth(); - Ok(Arc::::from(boxed)) + Ok(Arc::::from(boxed)) } RefFlag::RefValue => { context.inc_depth()?; let typeinfo = if read_type_info { context.read_any_type_info()? } else { - type_info - .ok_or_else(|| Error::type_error("No type info found for read Arc"))? + type_info.ok_or_else(|| { + Error::type_error("No type info found for read Arc") + })? }; // Check for generic container types which cannot be deserialized polymorphically check_generic_container_type(&typeinfo)?; let boxed = typeinfo .get_harness() - .read_polymorphic_data(context, &typeinfo)?; + .read_polymorphic_data_send_sync(context, &typeinfo)?; context.dec_depth(); - let arc: Arc = Arc::from(boxed); + let arc: Arc = Arc::from(boxed); context.ref_reader.store_arc_ref(arc.clone()); Ok(arc) } diff --git a/rust/fory-core/src/serializer/arc.rs b/rust/fory-core/src/serializer/arc.rs index a67746820c..5535dc7063 100644 --- a/rust/fory-core/src/serializer/arc.rs +++ b/rust/fory-core/src/serializer/arc.rs @@ -29,6 +29,14 @@ impl Serializer for Arc true } + #[inline(always)] + fn fory_is_threadsafe_type() -> bool + where + Self: Sized, + { + true + } + fn fory_write( &self, context: &mut WriteContext, @@ -120,6 +128,18 @@ impl Serializer for Arc Ok(Arc::new(inner)) } + #[inline] + fn fory_read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> + where + Self: Sized + ForyDefault, + { + Ok(crate::serializer::box_send_sync(Self::fory_read_data( + context, + )?)) + } + fn fory_read_type_info(context: &mut ReadContext) -> Result<(), Error> { T::fory_read_type_info(context) } diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 6ea272f067..d8006f692e 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -2748,7 +2748,7 @@ macro_rules! any_codec { any_codec!(AnyBoxCodec, Box); any_codec!(AnyRcCodec, Rc); -any_codec!(AnyArcCodec, Arc); +any_codec!(AnyArcCodec, Arc); #[cfg(test)] mod tests { diff --git a/rust/fory-core/src/serializer/core.rs b/rust/fory-core/src/serializer/core.rs index da4d91010f..25a9f0c94e 100644 --- a/rust/fory-core/src/serializer/core.rs +++ b/rust/fory-core/src/serializer/core.rs @@ -890,6 +890,35 @@ pub trait Serializer: 'static { where Self: Sized + ForyDefault; + /// Whether this serialized value can be safely represented behind + /// `Arc` after a dynamic read. + /// + /// The default is conservative. Implementations should return true only + /// when the concrete value produced by this serializer is `Send + Sync`. + #[inline(always)] + fn fory_is_threadsafe_type() -> bool + where + Self: Sized, + { + false + } + + /// Deserialize data for dynamic thread-safe carriers. + /// + /// This method must construct the `Box` from the + /// concrete value before it is erased as `dyn Any`. It must never try to + /// upgrade a `Box` returned by the ordinary dynamic read path. + #[inline(always)] + fn fory_read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> + where + Self: Sized + ForyDefault, + { + let _ = context; + Err(unsupported_threadsafe_type::()) + } + /// Read and validate type metadata from the buffer. /// /// This method reads type information to verify that the data in the buffer @@ -1437,6 +1466,23 @@ pub trait StructSerializer: Serializer + 'static { ) -> Result where Self: Sized; + + /// Deserialize compatible data for dynamic thread-safe carriers. + /// + /// Implementations are generated only when the resulting struct/enum is + /// known to be `Send + Sync`. + #[inline(always)] + fn fory_read_compatible_send_sync( + context: &mut ReadContext, + type_info: Rc, + ) -> Result, Error> + where + Self: Sized, + { + let _ = context; + let _ = type_info; + Err(unsupported_threadsafe_type::()) + } } /// Serializes an object implementing `Serializer` to the write context. @@ -1456,3 +1502,175 @@ pub fn write_data(this: &T, context: &mut WriteContext) -> Result pub fn read_data(context: &mut ReadContext) -> Result { T::fory_read_data(context) } + +#[inline(always)] +pub fn box_send_sync(value: T) -> Box +where + T: Any + Send + Sync, +{ + Box::new(value) +} + +#[inline(always)] +pub(crate) fn is_known_threadsafe_static_type_id(type_id: TypeId) -> bool { + matches!( + type_id, + TypeId::NONE + | TypeId::BOOL + | TypeId::INT8 + | TypeId::INT16 + | TypeId::INT32 + | TypeId::VARINT32 + | TypeId::INT64 + | TypeId::VARINT64 + | TypeId::TAGGED_INT64 + | TypeId::UINT8 + | TypeId::UINT16 + | TypeId::UINT32 + | TypeId::VAR_UINT32 + | TypeId::UINT64 + | TypeId::VAR_UINT64 + | TypeId::TAGGED_UINT64 + | TypeId::FLOAT16 + | TypeId::BFLOAT16 + | TypeId::FLOAT32 + | TypeId::FLOAT64 + | TypeId::STRING + | TypeId::DURATION + | TypeId::TIMESTAMP + | TypeId::DATE + | TypeId::DECIMAL + | TypeId::BINARY + | TypeId::BOOL_ARRAY + | TypeId::INT8_ARRAY + | TypeId::INT16_ARRAY + | TypeId::INT32_ARRAY + | TypeId::INT64_ARRAY + | TypeId::UINT8_ARRAY + | TypeId::UINT16_ARRAY + | TypeId::UINT32_ARRAY + | TypeId::UINT64_ARRAY + | TypeId::FLOAT16_ARRAY + | TypeId::BFLOAT16_ARRAY + | TypeId::FLOAT32_ARRAY + | TypeId::FLOAT64_ARRAY + | TypeId::U128 + | TypeId::INT128 + | TypeId::USIZE + | TypeId::ISIZE + | TypeId::U128_ARRAY + | TypeId::INT128_ARRAY + | TypeId::USIZE_ARRAY + | TypeId::ISIZE_ARRAY + ) +} + +pub(crate) fn read_known_threadsafe_data( + context: &mut ReadContext, +) -> Result, Error> +where + T: Serializer + ForyDefault, +{ + let boxed: Box = Box::new(T::fory_read_data(context)?); + box_known_threadsafe_data(T::fory_static_type_id(), boxed) +} + +#[cold] +#[inline(never)] +fn unexpected_threadsafe_type_id(type_id: TypeId) -> Error { + Error::type_error(format!( + "deserialized value did not match thread-safe static type id {:?}", + type_id + )) +} + +macro_rules! downcast_threadsafe_data { + ($boxed:expr, $type_id:expr, $ty:ty) => { + $boxed + .downcast::<$ty>() + .map(|value| value as Box) + .map_err(|_| unexpected_threadsafe_type_id($type_id)) + }; +} + +fn box_known_threadsafe_data( + type_id: TypeId, + boxed: Box, +) -> Result, Error> { + match type_id { + TypeId::NONE => downcast_threadsafe_data!(boxed, type_id, ()), + TypeId::BOOL => downcast_threadsafe_data!(boxed, type_id, bool), + TypeId::INT8 => downcast_threadsafe_data!(boxed, type_id, i8), + TypeId::INT16 => downcast_threadsafe_data!(boxed, type_id, i16), + TypeId::INT32 | TypeId::VARINT32 => downcast_threadsafe_data!(boxed, type_id, i32), + TypeId::INT64 | TypeId::VARINT64 | TypeId::TAGGED_INT64 => { + downcast_threadsafe_data!(boxed, type_id, i64) + } + TypeId::UINT8 => downcast_threadsafe_data!(boxed, type_id, u8), + TypeId::UINT16 => downcast_threadsafe_data!(boxed, type_id, u16), + TypeId::UINT32 | TypeId::VAR_UINT32 => downcast_threadsafe_data!(boxed, type_id, u32), + TypeId::UINT64 | TypeId::VAR_UINT64 | TypeId::TAGGED_UINT64 => { + downcast_threadsafe_data!(boxed, type_id, u64) + } + TypeId::FLOAT16 => { + downcast_threadsafe_data!(boxed, type_id, crate::types::float16::float16) + } + TypeId::BFLOAT16 => { + downcast_threadsafe_data!(boxed, type_id, crate::types::bfloat16::bfloat16) + } + TypeId::FLOAT32 => downcast_threadsafe_data!(boxed, type_id, f32), + TypeId::FLOAT64 => downcast_threadsafe_data!(boxed, type_id, f64), + TypeId::STRING => downcast_threadsafe_data!(boxed, type_id, String), + TypeId::DURATION => downcast_threadsafe_data!(boxed, type_id, crate::types::Duration), + TypeId::TIMESTAMP => downcast_threadsafe_data!(boxed, type_id, crate::types::Timestamp), + TypeId::DATE => downcast_threadsafe_data!(boxed, type_id, crate::types::Date), + TypeId::DECIMAL => downcast_threadsafe_data!(boxed, type_id, crate::types::Decimal), + TypeId::BINARY | TypeId::UINT8_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::BOOL_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::INT8_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::INT16_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::INT32_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::INT64_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::UINT16_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::UINT32_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::UINT64_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::FLOAT16_ARRAY => { + downcast_threadsafe_data!(boxed, type_id, Vec) + } + TypeId::BFLOAT16_ARRAY => { + downcast_threadsafe_data!(boxed, type_id, Vec) + } + TypeId::FLOAT32_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::FLOAT64_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::U128 => downcast_threadsafe_data!(boxed, type_id, u128), + TypeId::INT128 => downcast_threadsafe_data!(boxed, type_id, i128), + TypeId::USIZE => downcast_threadsafe_data!(boxed, type_id, usize), + TypeId::ISIZE => downcast_threadsafe_data!(boxed, type_id, isize), + TypeId::U128_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::INT128_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::USIZE_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + TypeId::ISIZE_ARRAY => downcast_threadsafe_data!(boxed, type_id, Vec), + _ => Err(unsupported_threadsafe_type_id(type_id)), + } +} + +#[cold] +#[inline(never)] +pub(crate) fn unsupported_threadsafe_type_id(type_id: TypeId) -> Error { + Error::type_error(format!( + "{:?} cannot be represented as Arc", + type_id + )) +} + +#[cold] +#[inline(never)] +pub fn unsupported_threadsafe_type() -> Error +where + T: ?Sized, +{ + Error::type_error(format!( + "{} cannot be represented as Arc", + std::any::type_name::() + )) +} diff --git a/rust/fory-core/src/serializer/mod.rs b/rust/fory-core/src/serializer/mod.rs index 8e8e23a41a..5de0bc095e 100644 --- a/rust/fory-core/src/serializer/mod.rs +++ b/rust/fory-core/src/serializer/mod.rs @@ -50,4 +50,8 @@ pub mod weak; mod core; mod decimal; pub use any::{read_box_any, write_box_any}; -pub use core::{read_data, write_data, ForyDefault, Serializer, StructSerializer}; +pub use core::{ + box_send_sync, read_data, unsupported_threadsafe_type, write_data, ForyDefault, Serializer, + StructSerializer, +}; +pub(crate) use core::{is_known_threadsafe_static_type_id, read_known_threadsafe_data}; diff --git a/rust/fory-core/src/serializer/trait_object.rs b/rust/fory-core/src/serializer/trait_object.rs index 4eb634d4fb..ce50ba1ac4 100644 --- a/rust/fory-core/src/serializer/trait_object.rs +++ b/rust/fory-core/src/serializer/trait_object.rs @@ -49,7 +49,8 @@ macro_rules! downcast_and_serialize { /// This macro automatically generates serializers for `Box` trait objects. /// Due to Rust's orphan rules, only `Box` is supported for user-defined traits. /// For `Rc` and `Arc`, wrapper types are generated (e.g., `TraitRc`, `TraitArc`), -/// either you use the wrapper types or use the `Rc` or `Arc` instead if it's not +/// either you use the wrapper types or use `Rc` or +/// `Arc` instead if it's not /// inside struct fields. For struct fields, you can use the `Rc`, `Arc` directly, /// fory will generate converters for `Rc` and `Arc` to convert to wrapper for /// diff --git a/rust/fory-core/src/serializer/unknown_case.rs b/rust/fory-core/src/serializer/unknown_case.rs index 424be4ba26..7423f384e9 100644 --- a/rust/fory-core/src/serializer/unknown_case.rs +++ b/rust/fory-core/src/serializer/unknown_case.rs @@ -48,7 +48,7 @@ fn write_typed_payload(context: &mut WriteContext, unknown: &UnknownCase) -> Res // polymorphic value. For internal numeric ids, the id byte is the complete // Any type metadata. Scalar Any payloads are not ref-tracked, so their ref // metadata is always NotNullValue before the original numeric encoding. - // Other types fall back to the normal Arc path. + // Other types fall back to the normal Arc path. context.writer.write_i8(RefFlag::NotNullValue as i8); context.writer.write_u8(type_id as u8); match type_id { @@ -130,7 +130,7 @@ pub fn read_payload(context: &mut ReadContext, case_id: u32) -> Result(ref_id) + .get_arc_ref::(ref_id) .ok_or_else(|| { Error::invalid_data(format!("UnknownCase ref {} not found", ref_id)) })?; @@ -141,6 +141,14 @@ pub fn read_payload(context: &mut ReadContext, case_id: u32) -> Result { + let ref_id = if matches!(ref_flag, RefFlag::RefValue) { + // The wire ref id belongs to the unknown payload itself. Reserve it + // before reading nested payload fields so their own refs keep the + // same ids written by the normal reference engine. + Some(context.ref_reader.reserve_ref_id()) + } else { + None + }; // The unknown-case serializer owns only the union payload envelope. It must // not add a depth frame here: the decoded Any value is not a new nesting // boundary by itself, and real nested payload serializers perform their @@ -149,10 +157,10 @@ pub fn read_payload(context: &mut ReadContext, case_id: u32) -> Result = Arc::from(boxed); - if matches!(ref_flag, RefFlag::RefValue) { - context.ref_reader.store_arc_ref(value.clone()); + .read_polymorphic_data_send_sync(context, &type_info)?; + let value: Arc = Arc::from(boxed); + if let Some(ref_id) = ref_id { + context.ref_reader.store_arc_ref_at(ref_id, value.clone()); } Ok(UnknownCase::from_runtime( case_id, @@ -200,6 +208,16 @@ impl Serializer for UnknownCase { read_payload(context, 0) } + fn fory_is_threadsafe_type() -> bool { + true + } + + fn fory_read_data_send_sync( + context: &mut ReadContext, + ) -> Result, Error> { + Ok(crate::serializer::box_send_sync(read_payload(context, 0)?)) + } + fn fory_get_type_id(_: &crate::resolver::TypeResolver) -> Result { Ok(TypeId::UNKNOWN) } diff --git a/rust/fory-core/src/types/unknown_case.rs b/rust/fory-core/src/types/unknown_case.rs index b947c1a143..0496ef4f43 100644 --- a/rust/fory-core/src/types/unknown_case.rs +++ b/rust/fory-core/src/types/unknown_case.rs @@ -27,8 +27,8 @@ pub struct UnknownCase { type_id: u32, // Keep resolver TypeInfo/Rc out of the carrier. Generated unions can outlive or move // independently from the resolver context, so the carrier stores only stable metadata - // plus the dynamic payload owned by Rust's existing polymorphic Arc path. - value: Arc, + // plus a dynamic payload whose thread-safety is guaranteed by the trait object. + value: Arc, } impl UnknownCase { @@ -38,7 +38,7 @@ impl UnknownCase { /// always uses the ordinary Any writer. pub fn new(case_id: u32, value: T) -> Self where - T: Any, + T: Any + Send + Sync, { Self { case_id, @@ -55,7 +55,7 @@ impl UnknownCase { self.type_id } - pub fn value(&self) -> &dyn Any { + pub fn value(&self) -> &(dyn Any + Send + Sync) { self.value.as_ref() } @@ -63,11 +63,15 @@ impl UnknownCase { self.value.downcast_ref::() } - pub(crate) fn value_arc(&self) -> &Arc { + pub(crate) fn value_arc(&self) -> &Arc { &self.value } - pub(crate) fn from_runtime(case_id: u32, type_id: u32, value: Arc) -> Self { + pub(crate) fn from_runtime( + case_id: u32, + type_id: u32, + value: Arc, + ) -> Self { Self { case_id, type_id, @@ -117,6 +121,13 @@ mod tests { hasher.finish() } + #[test] + fn unknown_case_is_send_sync() { + fn assert_send_sync() {} + + assert_send_sync::(); + } + #[test] fn equality_uses_carrier_identity() { let first = UnknownCase::new(7, String::from("future")); @@ -130,7 +141,7 @@ mod tests { #[test] fn replay_metadata_does_not_affect_identity() { - let value: Arc = Arc::new(String::from("future")); + let value: Arc = Arc::new(String::from("future")); let first = UnknownCase::from_runtime(7, 21, value.clone()); let same_payload = UnknownCase::from_runtime(8, 5, value); diff --git a/rust/fory-derive/src/object/field_codec.rs b/rust/fory-derive/src/object/field_codec.rs index e00acad292..8543e044de 100644 --- a/rust/fory-derive/src/object/field_codec.rs +++ b/rust/fory-derive/src/object/field_codec.rs @@ -20,7 +20,9 @@ use super::field_meta::{ ForyFieldMeta, IntEncoding, }; use super::read::create_private_field_name; -use super::util::get_type_id_by_type_ast; +use super::util::{ + get_type_id_by_type_ast, trait_object_is_any_send_sync, trait_object_is_any_without_auto_traits, +}; use crate::util::{is_arc_dyn_trait, is_box_dyn_trait, is_rc_dyn_trait, SourceField}; use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; @@ -388,7 +390,7 @@ fn field_dispatch_for( && is_container_type(ty) && !is_vec_type(ty) && !contains_custom_trait_object(ty) - && !contains_exact_any_object(ty) + && !contains_any_object(ty) { return Ok(FieldDispatch::Serializer { field_type: field_type_expr_for(ty, nullable, track_ref)?, @@ -531,8 +533,8 @@ pub(crate) fn codec_type_for( )?; if contains_custom_trait_object(key_ty) || contains_custom_trait_object(value_ty) - || contains_exact_any_object(key_ty) - || contains_exact_any_object(value_ty) + || contains_any_object(key_ty) + || contains_any_object(value_ty) { return Ok(quote! { ::fory_core::serializer::codec::HashMapCodec<#key_ty, #value_ty, #key_codec, #value_codec, #nullable, #track_ref> @@ -737,14 +739,35 @@ pub(crate) fn codec_type_for( } } - if is_exact_any(ty, "Box") { - return Ok(quote! { ::fory_core::serializer::codec::AnyBoxCodec<#nullable, #track_ref> }); + if let Some(trait_obj) = any_trait_object_for(ty, "Box") { + if trait_object_is_any_without_auto_traits(trait_obj) { + return Ok( + quote! { ::fory_core::serializer::codec::AnyBoxCodec<#nullable, #track_ref> }, + ); + } + return Ok(quote! { + compile_error!("Box is the supported owned Any carrier") + }); } - if is_exact_any(ty, "Rc") { - return Ok(quote! { ::fory_core::serializer::codec::AnyRcCodec<#nullable, #track_ref> }); + if let Some(trait_obj) = any_trait_object_for(ty, "Rc") { + if trait_object_is_any_without_auto_traits(trait_obj) { + return Ok( + quote! { ::fory_core::serializer::codec::AnyRcCodec<#nullable, #track_ref> }, + ); + } + return Ok(quote! { + compile_error!("Rc is the supported single-thread Any carrier") + }); } - if is_exact_any(ty, "Arc") { - return Ok(quote! { ::fory_core::serializer::codec::AnyArcCodec<#nullable, #track_ref> }); + if let Some(trait_obj) = any_trait_object_for(ty, "Arc") { + if trait_object_is_any_send_sync(trait_obj) { + return Ok( + quote! { ::fory_core::serializer::codec::AnyArcCodec<#nullable, #track_ref> }, + ); + } + return Ok(quote! { + compile_error!("Arc is not a shared thread-safe carrier; use Arc") + }); } if let Some((_, trait_name)) = is_box_dyn_trait(ty) { @@ -943,7 +966,7 @@ fn is_container_type(ty: &Type) -> bool { } fn contains_custom_trait_object(ty: &Type) -> bool { - if !is_exact_any(ty, "Box") && is_box_dyn_trait(ty).is_some() { + if any_trait_object_for(ty, "Box").is_none() && is_box_dyn_trait(ty).is_some() { return true; } if is_rc_dyn_trait(ty).is_some() || is_arc_dyn_trait(ty).is_some() { @@ -967,22 +990,25 @@ fn contains_custom_trait_object(ty: &Type) -> bool { }) } -fn contains_exact_any_object(ty: &Type) -> bool { - if is_exact_any(ty, "Box") || is_exact_any(ty, "Rc") || is_exact_any(ty, "Arc") { +fn contains_any_object(ty: &Type) -> bool { + if any_trait_object_for(ty, "Box").is_some() + || any_trait_object_for(ty, "Rc").is_some() + || any_trait_object_for(ty, "Arc").is_some() + { return true; } if let Some(inner) = extract_option_inner_type(ty) { - return contains_exact_any_object(&inner); + return contains_any_object(&inner); } if let Type::Array(array) = ty { - return contains_exact_any_object(array.elem.as_ref()); + return contains_any_object(array.elem.as_ref()); } let Some((_, Some(args))) = type_name_and_args(ty) else { return false; }; args.iter().any(|arg| { if let GenericArgument::Type(ty) = arg { - contains_exact_any_object(ty) + contains_any_object(ty) } else { false } @@ -1252,17 +1278,17 @@ fn validate_serializer_backed_map_meta( Ok(()) } -fn is_exact_any(ty: &Type, owner: &str) -> bool { +fn any_trait_object_for<'a>(ty: &'a Type, owner: &str) -> Option<&'a syn::TypeTraitObject> { let Some((name, Some(args))) = type_name_and_args(ty) else { - return false; + return None; }; if name != owner { - return false; + return None; } let Some(GenericArgument::Type(Type::TraitObject(trait_obj))) = args.first() else { - return false; + return None; }; - trait_obj.bounds.iter().any(|bound| { + if trait_obj.bounds.iter().any(|bound| { if let syn::TypeParamBound::Trait(trait_bound) = bound { trait_bound .path @@ -1272,7 +1298,11 @@ fn is_exact_any(ty: &Type, owner: &str) -> bool { } else { false } - }) + }) { + Some(trait_obj) + } else { + None + } } fn is_primitive_array_type(ty: &Type) -> bool { diff --git a/rust/fory-derive/src/object/serializer.rs b/rust/fory-derive/src/object/serializer.rs index db1212e40a..5721dd9dc9 100644 --- a/rust/fory-derive/src/object/serializer.rs +++ b/rust/fory-derive/src/object/serializer.rs @@ -52,6 +52,8 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea } else { quote! {} }; + let threadsafe_tokens = generate_threadsafe_tokens(ast); + let serializer_threadsafe_ts = threadsafe_tokens.serializer.clone(); // StructSerializer let ( @@ -60,6 +62,7 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea fields_info_ts, variants_fields_info_ts, read_compatible_ts, + read_compatible_send_sync_ts, enum_variant_meta_types, ) = match &ast.data { syn::Data::Struct(s) => { @@ -76,6 +79,7 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea misc::gen_field_fields_info(&source_fields), quote! { ::std::result::Result::Ok(::std::vec::Vec::new()) }, // No variants for structs read::gen_read_compatible(&source_fields), + threadsafe_tokens.struct_read_compatible.clone(), vec![], // No variant meta types for structs ) } @@ -92,6 +96,7 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea ::std::result::Result::Err(::fory_core::Error::not_allowed("`fory_read_compatible` should only be invoked at struct type" )) }, + quote! {}, variant_meta_types, ) } @@ -179,6 +184,8 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea fn fory_read_compatible(context: &mut ::fory_core::ReadContext, type_info: ::std::rc::Rc<::fory_core::TypeInfo>) -> ::std::result::Result { #read_compatible_ts } + + #read_compatible_send_sync_ts } impl #impl_generics ::fory_core::Serializer for #name #ty_generics #where_clause { @@ -243,6 +250,8 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea #read_data_ts } + #serializer_threadsafe_ts + #[inline(always)] fn fory_read_type_info(context: &mut ::fory_core::ReadContext) -> ::std::result::Result<(), ::fory_core::error::Error> { #read_type_info_ts @@ -254,6 +263,84 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea code } +struct ThreadsafeTokens { + serializer: proc_macro2::TokenStream, + struct_read_compatible: proc_macro2::TokenStream, +} + +fn generate_threadsafe_tokens(ast: &syn::DeriveInput) -> ThreadsafeTokens { + if !derive_type_is_threadsafe(ast) { + return ThreadsafeTokens { + serializer: quote! {}, + struct_read_compatible: quote! {}, + }; + } + let struct_read_compatible = if matches!(ast.data, syn::Data::Struct(_)) { + quote! { + #[inline] + fn fory_read_compatible_send_sync( + context: &mut ::fory_core::ReadContext, + type_info: ::std::rc::Rc<::fory_core::TypeInfo>, + ) -> ::std::result::Result<::std::boxed::Box, ::fory_core::error::Error> { + let value = ::fory_read_compatible(context, type_info)?; + ::std::result::Result::Ok(::fory_core::serializer::box_send_sync(value)) + } + } + } else { + quote! {} + }; + ThreadsafeTokens { + serializer: quote! { + #[inline(always)] + fn fory_is_threadsafe_type() -> bool + where + Self: Sized, + { + true + } + + #[inline] + fn fory_read_data_send_sync( + context: &mut ::fory_core::ReadContext, + ) -> ::std::result::Result<::std::boxed::Box, ::fory_core::error::Error> + where + Self: Sized + ::fory_core::ForyDefault, + { + let value = ::fory_read_data(context)?; + ::std::result::Result::Ok(::fory_core::serializer::box_send_sync(value)) + } + }, + struct_read_compatible, + } +} + +fn derive_type_is_threadsafe(ast: &syn::DeriveInput) -> bool { + use crate::object::util::{ + all_type_params_send_sync, type_is_threadsafe, type_param_send_sync_bounds, + }; + + // This is a syntactic filter for generating the send-sync reader. The + // generated reader still boxes `Self`, so Rust enforces the final + // `Send + Sync` invariant for nested user-defined field types. + if !all_type_params_send_sync(&ast.generics) { + return false; + } + let send_sync_params = type_param_send_sync_bounds(&ast.generics); + match &ast.data { + syn::Data::Struct(data) => data + .fields + .iter() + .all(|field| type_is_threadsafe(&field.ty, &send_sync_params)), + syn::Data::Enum(data) => data.variants.iter().all(|variant| { + variant + .fields + .iter() + .all(|field| type_is_threadsafe(&field.ty, &send_sync_params)) + }), + syn::Data::Union(_) => false, + } +} + fn generate_default_impl( ast: &syn::DeriveInput, generate_default: bool, diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs index f2c7e26d54..4fc38b8936 100644 --- a/rust/fory-derive/src/object/util.rs +++ b/rust/fory-derive/src/object/util.rs @@ -21,6 +21,7 @@ use fory_core::util::to_snake_case; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use std::cell::RefCell; +use std::collections::HashSet; use syn::{Field, Fields, GenericArgument, Index, PathArguments, Type}; /// Get field name for a field, handling both named and tuple struct fields. @@ -122,7 +123,7 @@ fn is_forward_field_internal(ty: &Type, struct_name: &str) -> bool { // Check smart pointers: Rc / Arc // Only return true if: - // 1. Inner type is Rc (polymorphic) + // 1. Inner type is dyn Any (polymorphic) // 2. Inner type references the containing struct (forward reference) if seg.ident == "Rc" || seg.ident == "Arc" { if let PathArguments::AngleBracketed(args) = &seg.arguments { @@ -133,12 +134,12 @@ fn is_forward_field_internal(ty: &Type, struct_name: &str) -> bool { if trait_obj .bounds .iter() - .any(|b| b.to_token_stream().to_string() == "Any") + .any(|b| trait_bound_ident(b).as_deref() == Some("Any")) { - // Rc → return true + // Rc/Arc needs polymorphic ref handling. return true; } else { - // Rc → return false + // Rc/Arc uses generated wrapper handling. return false; } } @@ -929,6 +930,155 @@ pub(crate) fn is_unknown_case_type(ty: &Type) -> bool { ) } +pub(crate) fn type_param_send_sync_bounds(generics: &syn::Generics) -> HashSet { + let mut params = HashSet::new(); + for param in generics.type_params() { + if bounds_include_send_sync(¶m.bounds) { + params.insert(param.ident.to_string()); + } + } + if let Some(where_clause) = &generics.where_clause { + for predicate in &where_clause.predicates { + let syn::WherePredicate::Type(predicate_ty) = predicate else { + continue; + }; + let Type::Path(type_path) = &predicate_ty.bounded_ty else { + continue; + }; + let Some(segment) = type_path.path.segments.last() else { + continue; + }; + if bounds_include_send_sync(&predicate_ty.bounds) { + params.insert(segment.ident.to_string()); + } + } + } + params +} + +pub(crate) fn all_type_params_send_sync(generics: &syn::Generics) -> bool { + let bounded = type_param_send_sync_bounds(generics); + generics + .type_params() + .all(|param| bounded.contains(¶m.ident.to_string())) +} + +pub(crate) fn type_is_threadsafe(ty: &Type, send_sync_params: &HashSet) -> bool { + match ty { + Type::Array(array) => type_is_threadsafe(array.elem.as_ref(), send_sync_params), + Type::Tuple(tuple) => tuple + .elems + .iter() + .all(|elem| type_is_threadsafe(elem, send_sync_params)), + Type::Path(type_path) => { + let Some(segment) = type_path.path.segments.last() else { + return false; + }; + let name = segment.ident.to_string(); + if send_sync_params.contains(&name) && matches!(segment.arguments, PathArguments::None) + { + return true; + } + match name.as_str() { + "bool" | "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" + | "u64" | "u128" | "usize" | "f32" | "f64" | "String" | "Date" | "Timestamp" + | "Duration" | "Decimal" | "float16" | "bfloat16" | "Float16" | "BFloat16" + | "UnknownCase" => true, + "Rc" | "RcWeak" | "RefCell" | "Cell" => false, + "Option" | "Vec" | "VecDeque" | "LinkedList" | "HashSet" | "BTreeSet" + | "BinaryHeap" | "Box" | "Arc" | "ArcWeak" | "Mutex" => { + let Some(inner) = first_type_arg(&segment.arguments) else { + return false; + }; + match (name.as_str(), inner) { + ("Box", Type::TraitObject(_)) => false, + ("Arc", Type::TraitObject(trait_obj)) => { + trait_object_is_any_send_sync(trait_obj) + } + (_, Type::TraitObject(_)) => false, + _ => type_is_threadsafe(inner, send_sync_params), + } + } + "HashMap" | "BTreeMap" => { + let Some((key, value)) = two_path_type_args(&segment.arguments) else { + return false; + }; + type_is_threadsafe(key, send_sync_params) + && type_is_threadsafe(value, send_sync_params) + } + _ => true, + } + } + _ => false, + } +} + +pub(crate) fn trait_object_is_any_send_sync(trait_obj: &syn::TypeTraitObject) -> bool { + trait_object_has_trait(trait_obj, "Any") + && trait_object_has_trait(trait_obj, "Send") + && trait_object_has_trait(trait_obj, "Sync") +} + +pub(crate) fn trait_object_is_any_without_auto_traits(trait_obj: &syn::TypeTraitObject) -> bool { + trait_object_has_trait(trait_obj, "Any") + && !trait_object_has_trait(trait_obj, "Send") + && !trait_object_has_trait(trait_obj, "Sync") +} + +fn first_type_arg(arguments: &PathArguments) -> Option<&Type> { + let PathArguments::AngleBracketed(args) = arguments else { + return None; + }; + args.args.iter().find_map(|arg| match arg { + GenericArgument::Type(ty) => Some(ty), + _ => None, + }) +} + +fn two_path_type_args(arguments: &PathArguments) -> Option<(&Type, &Type)> { + let PathArguments::AngleBracketed(args) = arguments else { + return None; + }; + let mut iter = args.args.iter().filter_map(|arg| match arg { + GenericArgument::Type(ty) => Some(ty), + _ => None, + }); + Some((iter.next()?, iter.next()?)) +} + +fn bounds_include_send_sync( + bounds: &syn::punctuated::Punctuated, +) -> bool { + let mut has_send = false; + let mut has_sync = false; + for bound in bounds { + match trait_bound_ident(bound).as_deref() { + Some("Send") => has_send = true, + Some("Sync") => has_sync = true, + _ => {} + } + } + has_send && has_sync +} + +fn trait_object_has_trait(trait_obj: &syn::TypeTraitObject, ident: &str) -> bool { + trait_obj + .bounds + .iter() + .any(|bound| trait_bound_ident(bound).as_deref() == Some(ident)) +} + +fn trait_bound_ident(bound: &syn::TypeParamBound) -> Option { + let syn::TypeParamBound::Trait(trait_bound) = bound else { + return None; + }; + trait_bound + .path + .segments + .last() + .map(|segment| segment.ident.to_string()) +} + // The typed-ADT forward-compatibility carrier is selected by a runtime marker, // not by a schema case id. Known schema cases may still use id 0. pub(crate) fn is_runtime_unknown_variant(variant: &syn::Variant) -> bool { diff --git a/rust/fory/src/lib.rs b/rust/fory/src/lib.rs index 78484d8cec..e190d0c8dd 100644 --- a/rust/fory/src/lib.rs +++ b/rust/fory/src/lib.rs @@ -340,7 +340,7 @@ //! - `Box` - Owned trait objects //! - `Rc` - Reference-counted trait objects //! - `Arc` - Thread-safe reference-counted trait objects -//! - `Rc` / `Arc` - Runtime type dispatch without custom traits +//! - `Rc` / `Arc` - Runtime type dispatch without custom traits //! - Collections: `Vec>`, `HashMap>` //! //! #### Basic Trait Object Serialization @@ -401,8 +401,9 @@ //! //! #### Serializing `dyn Any` Trait Objects //! -//! **What it does:** Supports serializing `Rc` and `Arc` for maximum -//! runtime type flexibility without defining custom traits. +//! **What it does:** Supports serializing `Rc` and +//! `Arc` for maximum runtime type flexibility without +//! defining custom traits. //! //! **When to use:** Plugin systems, dynamic type handling, or when you need runtime type //! dispatch without compile-time trait definitions. @@ -440,7 +441,7 @@ //! # } //! ``` //! -//! For thread-safe scenarios, use `Arc`: +//! For thread-safe scenarios, use `Arc`: //! //! ```rust //! use fory::Fory; @@ -456,12 +457,12 @@ //! let mut fory = Fory::builder().xlang(false).build(); //! fory.register::(101)?; //! -//! let cat: Arc = Arc::new(Cat { +//! let cat: Arc = Arc::new(Cat { //! name: "Whiskers".to_string() //! }); //! //! let bytes = fory.serialize(&cat)?; -//! let decoded: Arc = fory.deserialize(&bytes)?; +//! let decoded: Arc = fory.deserialize(&bytes)?; //! //! let unwrapped = decoded.downcast_ref::().unwrap(); //! assert_eq!(unwrapped.name, "Whiskers"); @@ -536,7 +537,7 @@ //! the `register_trait_type!` macro generates wrapper types. //! //! **Note:** If you don't want to use wrapper types, you can serialize as `Rc` -//! or `Arc` instead (see the `dyn Any` section above). +//! or `Arc` instead (see the `dyn Any` section above). //! //! The `register_trait_type!` macro generates `AnimalRc` and `AnimalArc` wrapper types: //! @@ -1005,7 +1006,7 @@ //! - `Rc` - Reference-counted trait objects //! - `Arc` - Thread-safe reference-counted trait objects //! - `Rc` - Runtime type dispatch without custom traits -//! - `Arc` - Thread-safe runtime type dispatch +//! - `Arc` - Thread-safe runtime type dispatch //! //! ## Wire Modes And Schema Evolution //! diff --git a/rust/tests/tests/test_any.rs b/rust/tests/tests/test_any.rs index df49db8af7..d7c4a7b574 100644 --- a/rust/tests/tests/test_any.rs +++ b/rust/tests/tests/test_any.rs @@ -77,22 +77,22 @@ fn test_rc_dyn_any() { fn test_arc_dyn_any() { let fory = Fory::builder().xlang(false).build(); - let value: Arc = Arc::new("arc test".to_string()); + let value: Arc = Arc::new("arc test".to_string()); let bytes = fory.serialize(&value).unwrap(); - let deserialized: Arc = fory.deserialize(&bytes).unwrap(); + let deserialized: Arc = fory.deserialize(&bytes).unwrap(); assert_eq!( deserialized.downcast_ref::().unwrap(), &"arc test".to_string() ); - let value2: Arc = Arc::new(123i32); + let value2: Arc = Arc::new(123i32); let bytes2 = fory.serialize(&value2).unwrap(); - let deserialized2: Arc = fory.deserialize(&bytes2).unwrap(); + let deserialized2: Arc = fory.deserialize(&bytes2).unwrap(); assert_eq!(deserialized2.downcast_ref::().unwrap(), &123i32); - let value3: Arc = Arc::new(vec![1, 2, 3]); + let value3: Arc = Arc::new(vec![1, 2, 3]); let bytes3 = fory.serialize(&value3).unwrap(); - let deserialized3: Arc = fory.deserialize(&bytes3).unwrap(); + let deserialized3: Arc = fory.deserialize(&bytes3).unwrap(); assert_eq!( deserialized3.downcast_ref::>().unwrap(), &vec![1, 2, 3] @@ -122,12 +122,12 @@ fn test_rc_dyn_any_shared_reference() { fn test_arc_dyn_any_shared_reference() { let fory = Fory::builder().xlang(false).build(); - let shared_vec: Arc = Arc::new(vec![1, 2, 3]); + let shared_vec: Arc = Arc::new(vec![1, 2, 3]); let data = vec![shared_vec.clone(), shared_vec.clone()]; let bytes = fory.serialize(&data).unwrap(); - let deserialized: Vec> = fory.deserialize(&bytes).unwrap(); + let deserialized: Vec> = fory.deserialize(&bytes).unwrap(); let first_vec = deserialized[0].downcast_ref::>().unwrap(); let second_vec = deserialized[1].downcast_ref::>().unwrap(); @@ -207,6 +207,11 @@ struct Container { items: Vec, } +#[derive(ForyStruct)] +struct ArcAnyHolder { + value: Arc, +} + #[derive(ForyStruct)] struct AnyMapVarKey { #[fory(id = 0)] @@ -257,17 +262,17 @@ fn test_arc_by_name() { items: vec!["a".to_string(), "b".to_string(), "c".to_string()], }; - let value: Arc = Arc::new(container); + let value: Arc = Arc::new(container); let bytes = fory.serialize(&value).unwrap(); - let deserialized: Arc = fory.deserialize(&bytes).unwrap(); + let deserialized: Arc = fory.deserialize(&bytes).unwrap(); let result = deserialized.downcast_ref::().unwrap(); assert_eq!(result.id, 999); assert_eq!(result.items, vec!["a", "b", "c"]); - let container_vec: Vec> = vec![value.clone(), value.clone()]; + let container_vec: Vec> = vec![value.clone(), value.clone()]; let bytes_vec = fory.serialize(&container_vec).unwrap(); - let deserialized_vec: Vec> = fory.deserialize(&bytes_vec).unwrap(); + let deserialized_vec: Vec> = fory.deserialize(&bytes_vec).unwrap(); assert_eq!(deserialized_vec.len(), 2); let first = deserialized_vec[0].downcast_ref::().unwrap(); let second = deserialized_vec[1].downcast_ref::().unwrap(); @@ -278,6 +283,28 @@ fn test_arc_by_name() { )); } +#[test] +fn test_arc_any_field_by_name() { + let mut fory = Fory::builder().xlang(false).build(); + fory.register_by_name::("", "Container").unwrap(); + fory.register_by_name::("", "ArcAnyHolder") + .unwrap(); + + let holder = ArcAnyHolder { + value: Arc::new(Container { + id: 777, + items: vec!["shared".to_string(), "any".to_string()], + }), + }; + + let bytes = fory.serialize(&holder).unwrap(); + let decoded: ArcAnyHolder = fory.deserialize(&bytes).unwrap(); + let container = decoded.value.downcast_ref::().unwrap(); + + assert_eq!(container.id, 777); + assert_eq!(container.items, vec!["shared", "any"]); +} + #[test] fn test_rc_by_name() { let mut fory = Fory::builder().xlang(false).build(); diff --git a/rust/tests/tests/test_enum.rs b/rust/tests/tests/test_enum.rs index 0b33e3b751..566baff9d8 100644 --- a/rust/tests/tests/test_enum.rs +++ b/rust/tests/tests/test_enum.rs @@ -222,6 +222,109 @@ fn union_compatible_enum_xlang_format() { assert_eq!(obj2, result2); } +#[test] +fn unknown_case_reads_threadsafe_generated_payload() { + use fory_core::ArcWeak; + use std::sync::{Arc, Mutex}; + + #[derive(ForyStruct, Debug)] + struct FutureLeaf { + label: String, + } + + #[derive(ForyStruct, Debug)] + struct FutureNode { + value: i32, + parent: ArcWeak>, + children: Vec>>, + } + + #[derive(ForyStruct, Debug)] + struct FuturePayload { + id: i32, + leaf: FutureLeaf, + primary: Arc, + alias: Arc, + root: Arc>, + } + + #[derive(ForyUnion, Debug)] + enum OldUnion { + #[fory(unknown)] + Unknown(fory_core::UnknownCase), + #[fory(id = 0, default)] + Known(String), + } + + #[derive(ForyUnion, Debug)] + enum NewUnion { + #[fory(unknown)] + Unknown(fory_core::UnknownCase), + #[fory(id = 0, default)] + Known(String), + #[fory(id = 1)] + Future(FuturePayload), + } + + let mut writer = Fory::builder() + .xlang(true) + .compatible(false) + .track_ref(true) + .build(); + writer.register::(400).unwrap(); + writer.register::(401).unwrap(); + writer.register::(402).unwrap(); + writer.register::(403).unwrap(); + + let mut reader = Fory::builder() + .xlang(true) + .compatible(false) + .track_ref(true) + .build(); + reader.register::(400).unwrap(); + reader.register::(401).unwrap(); + reader.register::(402).unwrap(); + reader.register::(403).unwrap(); + + let shared = Arc::new("shared".to_string()); + let root = Arc::new(Mutex::new(FutureNode { + value: 10, + parent: ArcWeak::new(), + children: vec![], + })); + let child = Arc::new(Mutex::new(FutureNode { + value: 20, + parent: ArcWeak::from(&root), + children: vec![], + })); + root.lock().unwrap().children.push(child); + let value = NewUnion::Future(FuturePayload { + id: 7, + leaf: FutureLeaf { + label: "nested".to_string(), + }, + primary: shared.clone(), + alias: shared, + root, + }); + + let bytes = writer.serialize(&value).unwrap(); + let decoded: OldUnion = reader.deserialize(&bytes).unwrap(); + + let OldUnion::Unknown(unknown) = decoded else { + panic!("expected unknown case"); + }; + assert_eq!(unknown.case_id(), 1); + let payload = unknown.downcast_ref::().unwrap(); + assert_eq!(payload.id, 7); + assert_eq!(payload.leaf.label, "nested"); + assert!(Arc::ptr_eq(&payload.primary, &payload.alias)); + let root = payload.root.clone(); + let child = root.lock().unwrap().children[0].clone(); + let parent = child.lock().unwrap().parent.upgrade().unwrap(); + assert!(Arc::ptr_eq(&root, &parent)); +} + #[test] fn union_payload_nested_codec_annotations_roundtrip() { #[derive(ForyUnion, Debug, PartialEq)] diff --git a/rust/tests/tests/test_helpers.rs b/rust/tests/tests/test_helpers.rs index c3afe5dbf0..3e9e626631 100644 --- a/rust/tests/tests/test_helpers.rs +++ b/rust/tests/tests/test_helpers.rs @@ -53,13 +53,13 @@ where assert_eq!(result.downcast_ref::().unwrap(), &value); } -/// Generic helper for testing Arc serialization +/// Generic helper for testing Arc serialization pub fn test_arc_any(fory: &Fory, value: T) where - T: 'static + PartialEq + std::fmt::Debug + Clone, + T: 'static + PartialEq + std::fmt::Debug + Clone + Send + Sync, { - let wrapped: Arc = Arc::new(value.clone()); + let wrapped: Arc = Arc::new(value.clone()); let bytes = fory.serialize(&wrapped).unwrap(); - let result: Arc = fory.deserialize(&bytes).unwrap(); + let result: Arc = fory.deserialize(&bytes).unwrap(); assert_eq!(result.downcast_ref::().unwrap(), &value); } diff --git a/rust/tests/tests/test_ref_resolver.rs b/rust/tests/tests/test_ref_resolver.rs index 643c077c92..23c6939478 100644 --- a/rust/tests/tests/test_ref_resolver.rs +++ b/rust/tests/tests/test_ref_resolver.rs @@ -20,6 +20,7 @@ use fory_core::buffer::Writer; use fory_core::resolver::{RefReader, RefWriter}; use fory_core::{ArcWeak, RcWeak}; +use std::any::Any; use std::rc::Rc; use std::sync::Arc; @@ -79,6 +80,20 @@ fn test_arc_storage_and_retrieval() { assert!(Arc::ptr_eq(&arc, &retrieved)); } +#[test] +fn test_arc_any_send_sync_storage_and_retrieval() { + let mut ref_reader = RefReader::new(); + let arc: Arc = Arc::new(String::from("test")); + + let ref_id = ref_reader.store_arc_ref(arc.clone()); + + let retrieved = ref_reader + .get_arc_ref::(ref_id) + .unwrap(); + assert_eq!(retrieved.downcast_ref::().unwrap(), "test"); + assert!(Arc::ptr_eq(&arc, &retrieved)); +} + #[test] fn test_ref_writer_clear() { let mut ref_writer = RefWriter::new(); diff --git a/rust/tests/tests/test_unsigned.rs b/rust/tests/tests/test_unsigned.rs index a7085215bb..4f4be267a5 100644 --- a/rust/tests/tests/test_unsigned.rs +++ b/rust/tests/tests/test_unsigned.rs @@ -435,7 +435,7 @@ fn test_unsigned_with_smart_pointers() { test_rc_any(&fory, usize::MAX); test_rc_any(&fory, u128::MAX); - // Test Arc with unsigned types + // Test Arc with unsigned types test_arc_any(&fory, u8::MAX); test_arc_any(&fory, u16::MAX); test_arc_any(&fory, u32::MAX); @@ -459,7 +459,7 @@ fn test_unsigned_with_smart_pointers() { test_rc_any(&fory, vec![0usize, 1000000000000, usize::MAX]); test_rc_any(&fory, vec![0u128, 1000000000000, u128::MAX]); - // Test Arc with unsigned arrays + // Test Arc with unsigned arrays test_arc_any(&fory, vec![0u8, 127, u8::MAX]); test_arc_any(&fory, vec![100u16, 200, 300, u16::MAX]); test_arc_any(&fory, vec![999u32, 888, 777, u32::MAX]);