Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/cgp-macro-lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ description = """
default = []

[dependencies]
syn = { version = "2.0.95", features = [ "full", "extra-traits", "visit" ] }
syn = { version = "2.0.95", features = [ "full", "extra-traits", "visit", "visit-mut" ] }
quote = "1.0.38"
proc-macro2 = "1.0.92"
prettyplease = "0.2.27"
Expand Down
6 changes: 6 additions & 0 deletions crates/cgp-macro-lib/src/cgp_impl/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use syn::punctuated::Punctuated;
use syn::token::Comma;

use crate::cgp_fn::UseTypeSpec;
use crate::cgp_impl::use_provider::UseProviderSpec;
use crate::parse::SimpleType;

pub fn parse_impl_attributes(attributes: &mut Vec<Attribute>) -> syn::Result<ImplAttributes> {
Expand All @@ -22,6 +23,10 @@ pub fn parse_impl_attributes(attributes: &mut Vec<Attribute>) -> syn::Result<Imp
let use_type = attribute
.parse_args_with(Punctuated::<UseTypeSpec, Comma>::parse_terminated)?;
parsed_attributes.use_type.extend(use_type);
} else if ident == "use_provider" {
let use_provider = attribute
.parse_args_with(Punctuated::<UseProviderSpec, Comma>::parse_terminated)?;
parsed_attributes.use_provider.extend(use_provider);
} else {
attributes.push(attribute);
}
Expand All @@ -37,4 +42,5 @@ pub fn parse_impl_attributes(attributes: &mut Vec<Attribute>) -> syn::Result<Imp
pub struct ImplAttributes {
pub uses: Vec<SimpleType>,
pub use_type: Vec<UseTypeSpec>,
pub use_provider: Vec<UseProviderSpec>,
}
20 changes: 19 additions & 1 deletion crates/cgp-macro-lib/src/cgp_impl/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Plus;
use syn::visit_mut::visit_item_impl_mut;
use syn::{ItemImpl, TypeParamBound, parse2};

use crate::cgp_fn::{apply_use_type_attributes_to_item_impl, build_implicit_args_bounds};
use crate::cgp_impl::attributes::parse_impl_attributes;
use crate::cgp_impl::provider_bounds::derive_provider_bounds;
use crate::cgp_impl::provider_call::TransformProviderCallVisitor;
use crate::cgp_impl::{ImplProviderSpec, derive_provider_impl, implicit_args};
use crate::derive_provider::{
derive_component_name_from_provider_impl, derive_is_provider_for, derive_provider_struct,
Expand Down Expand Up @@ -48,7 +51,13 @@ pub fn derive_cgp_impl(
})?);
}

let provider_impl = derive_provider_impl(&spec.provider_type, item_impl)?;
let mut visitor = TransformProviderCallVisitor::default();
visit_item_impl_mut(&mut visitor, &mut item_impl);
if let Some(err) = visitor.error {
return Err(err);
}

let (context_type, mut provider_impl) = derive_provider_impl(&spec.provider_type, item_impl)?;

let component_type = match &spec.component_type {
Some(component_type) => component_type.clone(),
Expand All @@ -57,6 +66,15 @@ pub fn derive_cgp_impl(

let is_provider_for_impl: ItemImpl = derive_is_provider_for(&component_type, &provider_impl)?;

if !attributes.use_provider.is_empty() {
let where_clause = provider_impl.generics.make_where_clause();

for spec in attributes.use_provider.iter() {
let provider_bounds = derive_provider_bounds(&context_type, spec)?;
where_clause.predicates.push(provider_bounds);
}
}

let provider_struct = if spec.new_struct {
Some(derive_provider_struct(&provider_impl)?)
} else {
Expand Down
3 changes: 3 additions & 0 deletions crates/cgp-macro-lib/src/cgp_impl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
mod attributes;
mod derive;
mod implicit_args;
mod provider_bounds;
mod provider_call;
mod provider_impl;
mod spec;
mod transform;
mod use_provider;

pub use derive::*;
pub use provider_impl::*;
Expand Down
41 changes: 41 additions & 0 deletions crates/cgp-macro-lib/src/cgp_impl/provider_bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use quote::{ToTokens, quote};
use syn::punctuated::Punctuated;
use syn::token::Plus;
use syn::{Type, TypeParamBound, WherePredicate, parse_quote, parse2};

use crate::cgp_impl::use_provider::UseProviderSpec;

pub fn derive_provider_bounds(
context_type: &Type,
spec: &UseProviderSpec,
) -> syn::Result<WherePredicate> {
let context_type = if spec.context_type == parse_quote! { Self } {
context_type
} else {
&spec.context_type
};

let provider_type = &spec.provider_type;
let mut bounds = Punctuated::<TypeParamBound, Plus>::new();

for bound in &spec.provider_trait_bounds {
let trait_ident = &bound.name;
let mut m_generics = bound.generics.clone();

let generics = m_generics.get_or_insert_with(|| parse_quote!(<>));
generics
.args
.insert(0, parse2(context_type.to_token_stream())?);

let trait_bound = parse2(quote! {
#trait_ident #generics
})?;
bounds.push(trait_bound);
}

let predicate = parse2(quote! {
#provider_type: #bounds
})?;

Ok(predicate)
}
69 changes: 69 additions & 0 deletions crates/cgp-macro-lib/src/cgp_impl/provider_call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use quote::quote;
use syn::parse::Parse;
use syn::visit_mut::VisitMut;
use syn::{Expr, ExprMethodCall, Type, parse2};

pub fn transform_provider_call(expr: &ExprMethodCall) -> syn::Result<Option<Expr>> {
let attributes = expr.attrs.clone();
let mut out_attributes = Vec::new();

let mut m_use_provider = None;

for attribute in attributes {
if attribute.path().is_ident("use_provider") {
if m_use_provider.is_some() {
return Err(syn::Error::new_spanned(
attribute,
"Multiple #[use_provider] attributes found",
));
}

m_use_provider = Some(attribute.parse_args_with(Type::parse)?);
} else {
out_attributes.push(attribute);
}
}

if let Some(provider_type) = m_use_provider {
let mut args = expr.args.clone();
args.insert(0, expr.receiver.as_ref().clone());

let method = &expr.method;
let turbofish = &expr.turbofish;

let new_expr: Expr = parse2(quote! {
#provider_type::#method #turbofish ( #args )
})?;

Ok(Some(new_expr))
} else {
Ok(None)
}
}

#[derive(Default)]
pub struct TransformProviderCallVisitor {
pub error: Option<syn::Error>,
}

impl VisitMut for TransformProviderCallVisitor {
fn visit_expr_mut(&mut self, expr: &mut syn::Expr) {
if self.error.is_some() {
return;
}

if let syn::Expr::MethodCall(method_call) = expr {
match transform_provider_call(method_call) {
Ok(Some(new_expr)) => {
*expr = new_expr;
}
Ok(None) => {}
Err(e) => {
self.error = Some(e);
}
}
}

syn::visit_mut::visit_expr_mut(self, expr);
}
}
14 changes: 9 additions & 5 deletions crates/cgp-macro-lib/src/cgp_impl/provider_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@ use crate::cgp_impl::transform_impl_trait;
pub fn derive_provider_impl(
provider_type: &Type,
mut item_impl: ItemImpl,
) -> syn::Result<ItemImpl> {
) -> syn::Result<(Type, ItemImpl)> {
match &item_impl.trait_ {
Some((_, path, _)) => {
let consumer_trait_path = parse2(path.to_token_stream())?;
let context_type = item_impl.self_ty.as_ref();
transform_impl_trait(
let item_trait = transform_impl_trait(
&item_impl,
&consumer_trait_path,
provider_type,
context_type,
)
)?;

Ok((context_type.clone(), item_trait))
}
None => {
let consumer_trait_path = parse2(item_impl.self_ty.to_token_stream())?;
Expand All @@ -27,12 +29,14 @@ pub fn derive_provider_impl(
.params
.insert(0, parse_quote! { __Context__ });

transform_impl_trait(
let item_trait = transform_impl_trait(
&item_impl,
&consumer_trait_path,
provider_type,
&context_type,
)
)?;

Ok((context_type, item_trait))
}
}
}
28 changes: 28 additions & 0 deletions crates/cgp-macro-lib/src/cgp_impl/use_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::{Colon, Comma};
use syn::{Type, parse_quote};

use crate::parse::SimpleType;

pub struct UseProviderSpec {
pub context_type: Type,
pub provider_type: Type,
pub provider_trait_bounds: Punctuated<SimpleType, Comma>,
}

impl Parse for UseProviderSpec {
fn parse(input: ParseStream) -> syn::Result<Self> {
let context_type = parse_quote!(Self);
let provider_type = input.parse()?;

let _: Colon = input.parse()?;
let provider_trait_bounds = Punctuated::parse_terminated(input)?;

Ok(Self {
context_type,
provider_type,
provider_trait_bounds,
})
}
}
1 change: 1 addition & 0 deletions crates/cgp-tests/tests/component_tests/cgp_impl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod basic;
pub mod implicit_args;
pub mod implicit_context;
pub mod use_provider;
24 changes: 24 additions & 0 deletions crates/cgp-tests/tests/component_tests/cgp_impl/use_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use cgp::prelude::*;

#[cgp_component(AreaCalculator)]
pub trait CanCalculateArea {
fn area(&self) -> f64;
}

#[cgp_impl(new RectangleArea)]
impl AreaCalculator {
fn area(&self, #[implicit] width: f64, #[implicit] height: f64) -> f64 {
width * height
}
}

#[cgp_impl(new ScaledArea<Inner>)]
#[use_provider(Inner: AreaCalculator)]
impl<Inner> AreaCalculator {
fn area(&self, #[implicit] scale_factor: f64) -> f64 {
let base_area = #[use_provider(Inner)]
self.area();

base_area * scale_factor * scale_factor
}
}
Loading