diff --git a/src/main/java/bio/terra/service/dataset/DatasetService.java b/src/main/java/bio/terra/service/dataset/DatasetService.java index c9b1e3aa56..5bc2566b85 100644 --- a/src/main/java/bio/terra/service/dataset/DatasetService.java +++ b/src/main/java/bio/terra/service/dataset/DatasetService.java @@ -165,16 +165,7 @@ public String createDataset( String description = "Create dataset " + datasetRequest.getName(); UUID defaultProfileId = datasetRequest.getDefaultProfileId(); loggingMetrics.set(BardEventProperties.BILLING_PROFILE_ID_FIELD_NAME, defaultProfileId); - - // Locate billing profile in TDR or Rawls - // No auth check: Just a check if there is an entry in our db for this billing profile - boolean isTdrBillingProfile; - try { - profileService.getProfileByIdNoCheck(defaultProfileId); - isTdrBillingProfile = true; - } catch (ProfileNotFoundException ex) { - isTdrBillingProfile = false; - } + boolean isTdrBillingProfile = profileService.isTdrBillingProfile(defaultProfileId); return jobService .newJob(description, DatasetCreateFlight.class, datasetRequest, userReq) diff --git a/src/main/java/bio/terra/service/profile/ProfileService.java b/src/main/java/bio/terra/service/profile/ProfileService.java index 6569634668..c24b430287 100644 --- a/src/main/java/bio/terra/service/profile/ProfileService.java +++ b/src/main/java/bio/terra/service/profile/ProfileService.java @@ -199,6 +199,19 @@ public BillingProfileModel getProfileByIdNoCheck(UUID id) { return profileDao.getBillingProfileById(id); } + public boolean isTdrBillingProfile(UUID defaultProfileId) { + // Locate billing profile in TDR or Rawls + // No auth check: Just a check if there is an entry in our db for this billing profile + boolean isTdrBillingProfile; + try { + getProfileByIdNoCheck(defaultProfileId); + isTdrBillingProfile = true; + } catch (ProfileNotFoundException ex) { + isTdrBillingProfile = false; + } + return isTdrBillingProfile; + } + // The idea is to use this call from create snapshot and create asset to validate that the // billing account is usable by the calling user diff --git a/src/main/java/bio/terra/service/snapshot/SnapshotService.java b/src/main/java/bio/terra/service/snapshot/SnapshotService.java index 289063023a..d88a848427 100644 --- a/src/main/java/bio/terra/service/snapshot/SnapshotService.java +++ b/src/main/java/bio/terra/service/snapshot/SnapshotService.java @@ -82,6 +82,7 @@ import bio.terra.service.filedata.google.firestore.FireStoreDependencyDao; import bio.terra.service.job.JobMapKeys; import bio.terra.service.job.JobService; +import bio.terra.service.profile.ProfileService; import bio.terra.service.rawls.RawlsService; import bio.terra.service.resourcemanagement.MetadataDataAccessUtils; import bio.terra.service.snapshot.exception.AssetNotFoundException; @@ -144,6 +145,7 @@ public class SnapshotService { private final RawlsService rawlsService; private final DuosClient duosClient; private final SnapshotBuilderSettingsDao snapshotBuilderSettingsDao; + private final ProfileService profileService; public SnapshotService( JobService jobService, @@ -159,7 +161,8 @@ public SnapshotService( AzureSynapsePdao azureSynapsePdao, RawlsService rawlsService, DuosClient duosClient, - SnapshotBuilderSettingsDao snapshotBuilderSettingsDao) { + SnapshotBuilderSettingsDao snapshotBuilderSettingsDao, + ProfileService profileService) { this.jobService = jobService; this.datasetService = datasetService; this.dependencyDao = dependencyDao; @@ -174,6 +177,7 @@ public SnapshotService( this.rawlsService = rawlsService; this.duosClient = duosClient; this.snapshotBuilderSettingsDao = snapshotBuilderSettingsDao; + this.profileService = profileService; } public String getSnapshotName(SnapshotRequestModel model) { @@ -248,11 +252,15 @@ public String createSnapshot( .map(model -> model.datasetName(dataset.getName())) .toList()); + boolean isTdrBillingProfile = + profileService.isTdrBillingProfile(snapshotRequestModel.getProfileId()); + return jobService .newJob(description, SnapshotCreateFlight.class, snapshotRequestModel, userReq) .addParameter(CommonMapKeys.CREATED_AT, Instant.now().toEpochMilli()) .addParameter(JobMapKeys.DATASET_ID.getKeyName(), dataset.getId()) .addParameter(JobMapKeys.SNAPSHOT_ID.getKeyName(), snapshotId.toString()) + .addParameter(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), isTdrBillingProfile) .submit(); } diff --git a/src/main/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlight.java b/src/main/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlight.java index eeb6668a44..866a0b4671 100644 --- a/src/main/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlight.java +++ b/src/main/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlight.java @@ -34,8 +34,10 @@ import bio.terra.service.policy.PolicyService; import bio.terra.service.profile.ProfileService; import bio.terra.service.profile.flight.AuthorizeBillingProfileUseStep; +import bio.terra.service.profile.flight.AuthorizeRawlsBillingProjectsUseStep; import bio.terra.service.profile.flight.VerifyBillingAccountAccessStep; import bio.terra.service.profile.google.GoogleBillingService; +import bio.terra.service.rawls.RawlsService; import bio.terra.service.resourcemanagement.BufferService; import bio.terra.service.resourcemanagement.ResourceService; import bio.terra.service.resourcemanagement.azure.AzureAuthService; @@ -109,6 +111,7 @@ public SnapshotCreateFlight(FlightMap inputParameters, Object applicationContext SnapshotBuilderService snapshotBuilderService = appContext.getBean(SnapshotBuilderService.class); IamService iamService = appContext.getBean(IamService.class); + RawlsService rawlsService = appContext.getBean(RawlsService.class); SnapshotRequestModel snapshotReq = inputParameters.get(JobMapKeys.REQUEST.getKeyName(), SnapshotRequestModel.class); @@ -140,13 +143,25 @@ public SnapshotCreateFlight(FlightMap inputParameters, Object applicationContext var platform = CloudPlatformWrapper.of(sourceDataset.getDatasetSummary().getStorageCloudPlatform()); + boolean isTdrBillingProfile = + inputParameters.get(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), Boolean.class); + // Take out a shared lock on the source dataset, to guard against it being deleted out from // under us (for example) addStep(new LockDatasetStep(datasetService, datasetId, true), randomBackoffRetry); - // Make sure this user is authorized to use the billing profile in SAM - addStep( - new AuthorizeBillingProfileUseStep(profileService, snapshotReq.getProfileId(), userReq)); + if (isTdrBillingProfile) { + // If using TDR billing profile, make sure this user is authorized to use the billing profile + // in Sam + addStep( + new AuthorizeBillingProfileUseStep(profileService, snapshotReq.getProfileId(), userReq)); + } else { + // If using Rawls billing project, make sure this user is authorized to use the billing + // project in Sam + addStep( + new AuthorizeRawlsBillingProjectsUseStep( + rawlsService, snapshotReq.getProfileId(), userReq)); + } if (platform.isGcp()) { addStep(new VerifyBillingAccountAccessStep(googleBillingService)); diff --git a/src/test/java/bio/terra/service/dataset/DatasetServiceUnitTest.java b/src/test/java/bio/terra/service/dataset/DatasetServiceUnitTest.java index 6a8c263103..a385f786b8 100644 --- a/src/test/java/bio/terra/service/dataset/DatasetServiceUnitTest.java +++ b/src/test/java/bio/terra/service/dataset/DatasetServiceUnitTest.java @@ -54,7 +54,6 @@ import bio.terra.service.load.LoadService; import bio.terra.service.profile.ProfileDao; import bio.terra.service.profile.ProfileService; -import bio.terra.service.profile.exception.ProfileNotFoundException; import bio.terra.service.resourcemanagement.MetadataDataAccessUtils; import bio.terra.service.resourcemanagement.ResourceService; import bio.terra.service.tabulardata.azure.StorageTableService; @@ -486,17 +485,12 @@ void createDataset(boolean isTDRBillingProfile) { datasetRequestModel, TEST_USER)) .thenReturn(jobBuilder); - - if (!isTDRBillingProfile) { - when(profileService.getProfileByIdNoCheck(defaultBillingProfile)) - .thenThrow(new ProfileNotFoundException("Profile not found")); - } + when(profileService.isTdrBillingProfile(defaultBillingProfile)).thenReturn(isTDRBillingProfile); ArgumentCaptor captor = ArgumentCaptor.forClass(FlightMap.class); when(jobService.submit(eq(DatasetCreateFlight.class), captor.capture())).thenReturn("JobId"); datasetService.createDataset(datasetRequestModel, TEST_USER); - verify(profileService).getProfileByIdNoCheck(defaultBillingProfile); FlightMap flightMap = captor.getValue(); assertThat( diff --git a/src/test/java/bio/terra/service/profile/ProfileServiceUnitTest.java b/src/test/java/bio/terra/service/profile/ProfileServiceUnitTest.java index 8d5f538de6..799a199d7f 100644 --- a/src/test/java/bio/terra/service/profile/ProfileServiceUnitTest.java +++ b/src/test/java/bio/terra/service/profile/ProfileServiceUnitTest.java @@ -1,7 +1,9 @@ package bio.terra.service.profile; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -24,6 +26,7 @@ import bio.terra.service.job.JobMapKeys; import bio.terra.service.job.JobService; import bio.terra.service.profile.azure.AzureAuthzService; +import bio.terra.service.profile.exception.ProfileNotFoundException; import bio.terra.service.profile.flight.ProfileMapKeys; import bio.terra.service.profile.flight.create.ProfileCreateFlight; import bio.terra.service.profile.flight.delete.ProfileDeleteFlight; @@ -191,4 +194,17 @@ void getProfileResources() { PROFILE_ID.toString(), IamAction.READ_SPEND_REPORT); } + + @Test + void isTdrBillingProfileIdTrue() { + when(profileDao.getBillingProfileById(PROFILE_ID)).thenReturn(new BillingProfileModel()); + assertTrue(profileService.isTdrBillingProfile(PROFILE_ID)); + } + + @Test + void isTdrBillingProfileIdFalse() { + when(profileDao.getBillingProfileById(PROFILE_ID)) + .thenThrow(new ProfileNotFoundException("Profile not found")); + assertFalse(profileService.isTdrBillingProfile(PROFILE_ID)); + } } diff --git a/src/test/java/bio/terra/service/snapshot/SnapshotServiceTest.java b/src/test/java/bio/terra/service/snapshot/SnapshotServiceTest.java index b49d6a925d..10749f32da 100644 --- a/src/test/java/bio/terra/service/snapshot/SnapshotServiceTest.java +++ b/src/test/java/bio/terra/service/snapshot/SnapshotServiceTest.java @@ -86,6 +86,7 @@ import bio.terra.service.job.JobBuilder; import bio.terra.service.job.JobMapKeys; import bio.terra.service.job.JobService; +import bio.terra.service.profile.ProfileService; import bio.terra.service.rawls.RawlsService; import bio.terra.service.resourcemanagement.MetadataDataAccessUtils; import bio.terra.service.resourcemanagement.google.GoogleProjectResource; @@ -115,6 +116,7 @@ import java.util.Map; import java.util.Set; import java.util.UUID; +import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -165,6 +167,7 @@ class SnapshotServiceTest { @Mock private RawlsService rawlsService; @Mock private DuosClient duosClient; @Mock private SnapshotBuilderSettingsDao settingsDao; + @Mock private ProfileService profileService; private final UUID snapshotId = UUID.randomUUID(); private final UUID datasetId = UUID.randomUUID(); private final UUID snapshotTableId = UUID.randomUUID(); @@ -192,7 +195,8 @@ void beforeEach() { azureSynapsePdao, rawlsService, duosClient, - settingsDao); + settingsDao, + profileService); } @Test @@ -1085,42 +1089,15 @@ private SnapshotRequestModel getDuosSnapshotRequestModel(String duosId) { @Test void testCreateSnapshotWithoutDuosDataset() { SnapshotRequestModel request = getDuosSnapshotRequestModel(null); - JobBuilder jobBuilder = mock(JobBuilder.class); - when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER))) - .thenReturn(jobBuilder); - when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder); - String jobId = String.valueOf(UUID.randomUUID()); - when(jobBuilder.submit()).thenReturn(jobId); - - String result = - service.createSnapshot( - request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER); - assertThat("Job is submitted and id returned", result, equalTo(jobId)); + mockCreateSnapshot(request); verify(duosClient, never()).getDataset(DUOS_ID, TEST_USER); - verify(jobBuilder).submit(); } @Test void testCreateSnapshotWithDuosDataset() { SnapshotRequestModel request = getDuosSnapshotRequestModel(DUOS_ID); - JobBuilder jobBuilder = mock(JobBuilder.class); - String jobId = mockJobService(request, jobBuilder); - - String result = - service.createSnapshot( - request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER); - assertThat("Job is submitted and id returned", result, equalTo(jobId)); + mockCreateSnapshot(request); verify(duosClient).getDataset(DUOS_ID, TEST_USER); - verify(jobBuilder).submit(); - } - - private String mockJobService(SnapshotRequestModel request, JobBuilder jobBuilder) { - when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER))) - .thenReturn(jobBuilder); - when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder); - String jobId = String.valueOf(UUID.randomUUID()); - when(jobBuilder.submit()).thenReturn(jobId); - return jobId; } @Test @@ -1144,19 +1121,13 @@ void testCreateSnapshotWithByRequestId() { makeByRequestIdContentsModel(snapshotAccessRequestId); SnapshotRequestModel request = new SnapshotRequestModel().contents(List.of(contentsModel)); request.profileId(UUID.randomUUID()); - JobBuilder jobBuilder = mock(JobBuilder.class); - String jobId = mockJobService(request, jobBuilder); when(snapshotRequestDao.getById(snapshotAccessRequestId)).thenReturn(snapshotAccessRequest); when(snapshotDao.retrieveSnapshot(snapshotAccessRequest.sourceSnapshotId())) .thenReturn( new Snapshot() .snapshotSources( List.of(new SnapshotSource().dataset(new Dataset().id(UUID.randomUUID()))))); - - String result = - service.createSnapshot( - request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER); - assertThat("Job is submitted and id returned", result, equalTo(jobId)); + mockCreateSnapshot(request); } private void mockSnapshotWithDuosDataset() { @@ -1201,7 +1172,6 @@ void testCreateSnapshotWithAuthDomainAndInheritStewardDisabled() { .profileId(UUID.randomUUID()) .addDataAccessControlGroupsItem("AuthDomain1") .contents(List.of(new SnapshotRequestContentsModel().datasetName(sourceDatasetName))); - mockCreateSnapshot(request); } @@ -1219,17 +1189,60 @@ void testCreateSnapshotInheritStewardEnabled() { .name("TestSnapshot") .profileId(UUID.randomUUID()) .contents(List.of(new SnapshotRequestContentsModel().datasetName(sourceDatasetName))); - mockCreateSnapshot(request); } - private void mockCreateSnapshot(SnapshotRequestModel request) { - JobBuilder jobBuilder = mock(JobBuilder.class); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void createSnapshot(boolean isTDRBillingProfile) { + UUID profileId = UUID.randomUUID(); + UUID datasetId = UUID.randomUUID(); + String jobId = UUID.randomUUID().toString(); + SnapshotRequestModel snapshotRequestModel = + new SnapshotRequestModel() + .name("TestSnapshot") + .profileId(profileId) + .contents(List.of(new SnapshotRequestContentsModel().datasetName("TestSourceDataset"))); + Dataset sourceDataset = new Dataset(new DatasetSummary().id(datasetId)); + + when(profileService.isTdrBillingProfile(profileId)).thenReturn(isTDRBillingProfile); + ArgumentCaptor captor = mockJobServiceWithCaptor(snapshotRequestModel, jobId); + + assertThat( + service.createSnapshot(snapshotRequestModel, sourceDataset, TEST_USER), equalTo(jobId)); + FlightMap flightMap = captor.getValue(); + assertThat(flightMap.get(JobMapKeys.DATASET_ID.getKeyName(), UUID.class), equalTo(datasetId)); + assertThat( + flightMap.get(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), Boolean.class), + equalTo(isTDRBillingProfile)); + } + + @NotNull + private ArgumentCaptor mockJobServiceWithCaptor( + SnapshotRequestModel snapshotRequestModel, String jobId) { + JobBuilder jobBuilder = + new JobBuilder("", SnapshotCreateFlight.class, snapshotRequestModel, TEST_USER, jobService); + when(jobService.newJob( + anyString(), eq(SnapshotCreateFlight.class), eq(snapshotRequestModel), eq(TEST_USER))) + .thenReturn(jobBuilder); + ArgumentCaptor captor = ArgumentCaptor.forClass(FlightMap.class); + when(jobService.submit(eq(SnapshotCreateFlight.class), captor.capture())).thenReturn(jobId); + return captor; + } + + private String mockJobService(SnapshotRequestModel request, JobBuilder jobBuilder) { when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER))) .thenReturn(jobBuilder); when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder); String jobId = String.valueOf(UUID.randomUUID()); when(jobBuilder.submit()).thenReturn(jobId); + return jobId; + } + + private void mockCreateSnapshot(SnapshotRequestModel request) { + JobBuilder jobBuilder = mock(JobBuilder.class); + String jobId = mockJobService(request, jobBuilder); + when(profileService.isTdrBillingProfile(request.getProfileId())).thenReturn(false); String result = service.createSnapshot( diff --git a/src/test/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlightTest.java b/src/test/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlightTest.java index c01dd3e056..a0967ce4ad 100644 --- a/src/test/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlightTest.java +++ b/src/test/java/bio/terra/service/snapshot/flight/create/SnapshotCreateFlightTest.java @@ -29,6 +29,8 @@ import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.context.ApplicationContext; @@ -65,6 +67,7 @@ void beforeEach() { inputParameters.put(JobMapKeys.DATASET_ID.getKeyName(), datasetId); inputParameters.put(JobMapKeys.SNAPSHOT_ID.getKeyName(), UUID.randomUUID()); + inputParameters.put(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), true); } @Test @@ -157,4 +160,28 @@ void testSnapshotCreateFlightByRequestId() { "AddCreatedInfoToSnapshotRequestStep", "NotifyUserOfSnapshotCreationStep")); } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testSnapshotCreateFlightTDRBillingProfile(boolean tdrBillingProfileFallback) { + SnapshotRequestModel request = + new SnapshotRequestModel() + .addContentsItem( + new SnapshotRequestContentsModel() + .mode(SnapshotRequestContentsModel.ModeEnum.BYFULLVIEW)); + inputParameters.put(JobMapKeys.REQUEST.getKeyName(), request); + inputParameters.put( + JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), tdrBillingProfileFallback); + var flight = new SnapshotCreateFlight(inputParameters, context); + + if (tdrBillingProfileFallback) { + assertThat( + FlightTestUtils.getStepNames(flight), + containsInRelativeOrder("AuthorizeBillingProfileUseStep")); + } else { + assertThat( + FlightTestUtils.getStepNames(flight), + containsInRelativeOrder("AuthorizeRawlsBillingProjectsUseStep")); + } + } }